[3] [contrib/metrics] setup typing in contrib part of the library #1363

Merged: 4 commits, Oct 6, 2020
10 changes: 7 additions & 3 deletions ignite/contrib/metrics/average_precision.py
@@ -1,7 +1,11 @@
from typing import Callable

import torch

from ignite.metrics import EpochMetric


-def average_precision_compute_fn(y_preds, y_targets):
+def average_precision_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor):
try:
from sklearn.metrics import average_precision_score
except ImportError:
@@ -22,7 +26,7 @@ class AveragePrecision(EpochMetric):
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
-check_compute_fn (bool): Optional default False. If True, `average_precision_score
+check_compute_fn (bool): Default False. If True, `average_precision_score
<http://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html
#sklearn.metrics.average_precision_score>`_ is run on the first batch of data to ensure there are
no issues. User will be warned in case there are any issues computing the function.
@@ -41,7 +45,7 @@ def activated_output_transform(output):

"""

-def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False):
+def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False):
super(AveragePrecision, self).__init__(
average_precision_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
)
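
A minimal usage sketch of the class above, with hypothetical tensors (scikit-learn must be installed, since the compute function imports `sklearn.metrics.average_precision_score`):

```python
import torch

from ignite.contrib.metrics import AveragePrecision

# Update/compute directly, without an Engine; output_transform
# (now annotated as Callable) defaults to the identity.
ap = AveragePrecision()
y_pred = torch.tensor([0.1, 0.8, 0.6, 0.3])  # predicted scores
y_true = torch.tensor([0, 1, 1, 0])          # binary targets
ap.update((y_pred, y_true))
print(ap.compute())  # average precision as computed by scikit-learn
```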
9 changes: 5 additions & 4 deletions ignite/contrib/metrics/gpu_info.py
@@ -1,9 +1,10 @@
# -*- coding: utf-8 -*-
import warnings
from typing import Tuple, Union

import torch

-from ignite.engine import Events
+from ignite.engine import Engine, EventEnum, Events
from ignite.metrics import Metric


@@ -54,7 +55,7 @@ def __init__(self):
def reset(self):
pass

-def update(self, output):
+def update(self, output: Tuple[torch.Tensor, torch.Tensor]):
pass

def compute(self):
@@ -64,7 +65,7 @@ def compute(self):
return []
return data["gpu"]

-def completed(self, engine, name):
+def completed(self, engine: Engine, name: str):
data = self.compute()
if len(data) < 1:
warnings.warn("No GPU information available")
@@ -103,5 +104,5 @@ def completed(self, engine, name):
# Do not set GPU utilization information
pass

-def attach(self, engine, name="gpu", event_name=Events.ITERATION_COMPLETED):
+def attach(self, engine: Engine, name: str = "gpu", event_name: Union[str, EventEnum] = Events.ITERATION_COMPLETED):
engine.add_event_handler(event_name, self.completed, name)
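
A sketch of the typed `attach()` in use (`GpuInfo` requires the optional `pynvml` package and a visible NVIDIA GPU, otherwise its constructor raises):

```python
from ignite.contrib.metrics import GpuInfo
from ignite.engine import Engine, Events

def train_step(engine, batch):
    return batch  # placeholder training step for illustration

trainer = Engine(train_step)
# Defaults mirror the new annotation: name: str = "gpu",
# event_name: Union[str, EventEnum] = Events.ITERATION_COMPLETED
GpuInfo().attach(trainer, name="gpu", event_name=Events.ITERATION_COMPLETED)
# After each iteration, entries such as "gpu:0 mem(%)" appear in
# trainer.state.metrics.
```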
10 changes: 7 additions & 3 deletions ignite/contrib/metrics/precision_recall_curve.py
@@ -1,7 +1,11 @@
from typing import Callable

import torch

from ignite.metrics import EpochMetric


-def precision_recall_curve_compute_fn(y_preds, y_targets):
+def precision_recall_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor):
try:
from sklearn.metrics import precision_recall_curve
except ImportError:
@@ -23,7 +27,7 @@ class PrecisionRecallCurve(EpochMetric):
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
-check_compute_fn (bool): Optional default False. If True, `precision_recall_curve
+check_compute_fn (bool): Default False. If True, `precision_recall_curve
<http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html
#sklearn.metrics.precision_recall_curve>`_ is run on the first batch of data to ensure there are
no issues. User will be warned in case there are any issues computing the function.
@@ -42,7 +46,7 @@ def activated_output_transform(output):

"""

-def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False):
+def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False):
super(PrecisionRecallCurve, self).__init__(
precision_recall_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
)
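
A quick check of the typed compute function itself, assuming scikit-learn is available (it wraps `sklearn.metrics.precision_recall_curve`, so it returns the same `(precision, recall, thresholds)` triple):

```python
import torch

from ignite.contrib.metrics.precision_recall_curve import precision_recall_curve_compute_fn

y_preds = torch.tensor([0.1, 0.4, 0.35, 0.8])  # predicted scores
y_targets = torch.tensor([0, 0, 1, 1])         # binary targets
precision, recall, thresholds = precision_recall_curve_compute_fn(y_preds, y_targets)
print(precision, recall, thresholds)
```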
18 changes: 10 additions & 8 deletions ignite/contrib/metrics/regression/_base.py
@@ -1,13 +1,13 @@
from abc import abstractmethod
-from typing import Callable, Union
+from typing import Callable, Tuple

import torch

from ignite.metrics import EpochMetric, Metric
from ignite.metrics.metric import reinit__is_reduced


-def _check_output_shapes(output):
+def _check_output_shapes(output: Tuple[torch.Tensor, torch.Tensor]):
y_pred, y = output
if y_pred.shape != y.shape:
raise ValueError("Input data shapes should be the same, but given {} and {}".format(y_pred.shape, y.shape))
@@ -21,7 +21,7 @@ def _check_output_shapes(output):
raise ValueError("Input y should have shape (N,) or (N, 1), but given {}".format(y.shape))


-def _check_output_types(output):
+def _check_output_types(output: Tuple[torch.Tensor, torch.Tensor]):
y_pred, y = output
if y_pred.dtype not in (torch.float16, torch.float32, torch.float64):
raise TypeError("Input y_pred dtype should be float 16, 32 or 64, but given {}".format(y_pred.dtype))
@@ -36,7 +36,7 @@ class _BaseRegression(Metric):
# method `_update`.

@reinit__is_reduced
-def update(self, output):
+def update(self, output: Tuple[torch.Tensor, torch.Tensor]):
_check_output_shapes(output)
_check_output_types(output)
y_pred, y = output[0].detach(), output[1].detach()
@@ -50,7 +50,7 @@ def update(self, output):
self._update((y_pred, y))

@abstractmethod
-def _update(self, output):
+def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
pass


@@ -59,14 +59,16 @@ class _BaseRegressionEpoch(EpochMetric):
# `update` method check the shapes and call internal overloaded method `_update`.
# Class internally stores complete history of predictions and targets of type float32.

-def __init__(self, compute_fn, output_transform=lambda x: x, check_compute_fn: bool = True):
+def __init__(
+    self, compute_fn: Callable, output_transform: Callable = lambda x: x, check_compute_fn: bool = True,
+):
super(_BaseRegressionEpoch, self).__init__(
compute_fn=compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
)

-def _check_type(self, output):
+def _check_type(self, output: Tuple[torch.Tensor, torch.Tensor]):
_check_output_types(output)
super(_BaseRegressionEpoch, self)._check_type(output)

-def _check_shape(self, output):
+def _check_shape(self, output: Tuple[torch.Tensor, torch.Tensor]):
_check_output_shapes(output)
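
The `Tuple[torch.Tensor, torch.Tensor]` annotation on `_update` documents the contract for every regression metric: `update()` validates shapes and dtypes, then hands a `(y_pred, y)` pair to `_update`. A hypothetical subclass sketch (the class name and logic are illustrative only, not part of the library):

```python
from typing import Tuple

import torch

from ignite.contrib.metrics.regression._base import _BaseRegression

class SumOfErrors(_BaseRegression):
    """Toy metric accumulating the signed error sum."""

    def reset(self):
        self._sum = 0.0

    def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
        y_pred, y = output  # already shape- and dtype-checked by update()
        self._sum += torch.sum(y.view_as(y_pred) - y_pred).item()

    def compute(self):
        return self._sum

m = SumOfErrors()
m.update((torch.tensor([1.0, 2.0]), torch.tensor([1.5, 2.5])))
print(m.compute())  # 1.0
```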
4 changes: 2 additions & 2 deletions ignite/contrib/metrics/regression/canberra_metric.py
@@ -1,4 +1,4 @@
-from typing import Callable, Union
+from typing import Callable, Tuple, Union

import torch

@@ -34,7 +34,7 @@ def __init__(
def reset(self):
self._sum_of_errors = torch.tensor(0.0, device=self._device)

-def _update(self, output):
+def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
y_pred, y = output
errors = torch.abs(y - y_pred) / (torch.abs(y_pred) + torch.abs(y))
self._sum_of_errors += torch.sum(errors).to(self._device)
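
A numeric sanity check of the update rule above, with hypothetical values (errors = |y - y_pred| / (|y_pred| + |y|), accumulated as a sum):

```python
import torch

from ignite.contrib.metrics.regression import CanberraMetric

m = CanberraMetric()
m.update((torch.tensor([1.0, 3.0]), torch.tensor([1.0, 1.0])))  # (y_pred, y)
# element-wise errors: |1-1|/(1+1) = 0.0 and |1-3|/(3+1) = 0.5
print(m.compute())  # 0.5
```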
@@ -1,3 +1,5 @@
from typing import Tuple

import torch

from ignite.contrib.metrics.regression._base import _BaseRegression
@@ -24,7 +26,7 @@ def reset(self):
self._sum_of_errors = 0.0
self._num_examples = 0

-def _update(self, output):
+def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
y_pred, y = output
errors = 2 * torch.abs(y.view_as(y_pred) - y_pred) / (torch.abs(y_pred) + torch.abs(y.view_as(y_pred)))
self._sum_of_errors += torch.sum(errors).item()
4 changes: 3 additions & 1 deletion ignite/contrib/metrics/regression/fractional_bias.py
@@ -1,3 +1,5 @@
from typing import Tuple

import torch

from ignite.contrib.metrics.regression._base import _BaseRegression
@@ -25,7 +27,7 @@ def reset(self):
self._sum_of_errors = 0.0
self._num_examples = 0

-def _update(self, output):
+def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
y_pred, y = output
errors = 2 * (y.view_as(y_pred) - y_pred) / (y_pred + y.view_as(y_pred))
self._sum_of_errors += torch.sum(errors).item()
@@ -1,3 +1,5 @@
from typing import Tuple

import torch

from ignite.contrib.metrics.regression._base import _BaseRegression
@@ -24,7 +26,7 @@ def reset(self):
self._sum_of_errors = 0.0
self._num_examples = 0

-def _update(self, output):
+def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
y_pred, y = output
errors = torch.log(torch.abs(y.view_as(y_pred) - y_pred))
self._sum_of_errors += torch.sum(errors)
@@ -1,6 +1,9 @@
from typing import Tuple

import torch

from ignite.contrib.metrics.regression._base import _BaseRegression
from ignite.exceptions import NotComputableError


class GeometricMeanRelativeAbsoluteError(_BaseRegression):
@@ -26,7 +29,7 @@ def reset(self):
self._num_examples = 0
self._sum_of_errors = 0.0

-def _update(self, output):
+def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
y_pred, y = output
self._sum_y += y.sum()
self._num_examples += y.shape[0]
4 changes: 2 additions & 2 deletions ignite/contrib/metrics/regression/manhattan_distance.py
@@ -1,4 +1,4 @@
-from typing import Callable, Union
+from typing import Callable, Tuple, Union

import torch

@@ -33,7 +33,7 @@ def reset(self):
def reset(self):
self._sum_of_errors = torch.tensor(0.0, device=self._device)

-def _update(self, output):
+def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
y_pred, y = output
errors = torch.abs(y - y_pred)
self._sum_of_errors += torch.sum(errors).to(self._device)
4 changes: 3 additions & 1 deletion ignite/contrib/metrics/regression/maximum_absolute_error.py
@@ -1,3 +1,5 @@
from typing import Tuple

import torch

from ignite.contrib.metrics.regression._base import _BaseRegression
@@ -24,7 +26,7 @@ class MaximumAbsoluteError(_BaseRegression):
def reset(self):
self._max_of_absolute_errors = -1

-def _update(self, output):
+def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
y_pred, y = output
mae = torch.abs(y_pred - y.view_as(y_pred)).max().item()
if self._max_of_absolute_errors < mae:
@@ -1,3 +1,5 @@
from typing import Tuple

import torch

from ignite.contrib.metrics.regression._base import _BaseRegression
@@ -25,7 +27,7 @@ def reset(self):
self._sum_of_absolute_relative_errors = 0.0
self._num_samples = 0

-def _update(self, output):
+def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
y_pred, y = output
if (y == 0).any():
raise NotComputableError("The ground truth has 0.")
4 changes: 3 additions & 1 deletion ignite/contrib/metrics/regression/mean_error.py
@@ -1,3 +1,5 @@
from typing import Tuple

import torch

from ignite.contrib.metrics.regression._base import _BaseRegression
@@ -25,7 +27,7 @@ def reset(self):
self._sum_of_errors = 0.0
self._num_examples = 0

-def _update(self, output):
+def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
y_pred, y = output
errors = y.view_as(y_pred) - y_pred
self._sum_of_errors += torch.sum(errors).item()
4 changes: 3 additions & 1 deletion ignite/contrib/metrics/regression/mean_normalized_bias.py
@@ -1,3 +1,5 @@
from typing import Tuple

import torch

from ignite.contrib.metrics.regression._base import _BaseRegression
@@ -25,7 +27,7 @@ def reset(self):
self._sum_of_errors = 0.0
self._num_examples = 0

-def _update(self, output):
+def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
y_pred, y = output

if (y == 0).any():
6 changes: 4 additions & 2 deletions ignite/contrib/metrics/regression/median_absolute_error.py
@@ -1,9 +1,11 @@
from typing import Callable

import torch

from ignite.contrib.metrics.regression._base import _BaseRegressionEpoch


-def median_absolute_error_compute_fn(y_pred, y):
+def median_absolute_error_compute_fn(y_pred: torch.Tensor, y: torch.Tensor):
e = torch.abs(y.view_as(y_pred) - y_pred)
return torch.median(e).item()

@@ -31,5 +33,5 @@ class MedianAbsoluteError(_BaseRegressionEpoch):

"""

-def __init__(self, output_transform=lambda x: x):
+def __init__(self, output_transform: Callable = lambda x: x):
super(MedianAbsoluteError, self).__init__(median_absolute_error_compute_fn, output_transform)
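
A quick check of the typed compute function with hypothetical tensors; it returns the median absolute error as a Python float:

```python
import torch

from ignite.contrib.metrics.regression.median_absolute_error import median_absolute_error_compute_fn

y_pred = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.5, 2.0, 5.0])
# absolute errors: [0.5, 0.0, 2.0] -> median is 0.5
print(median_absolute_error_compute_fn(y_pred, y))
```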
@@ -1,9 +1,11 @@
from typing import Callable

import torch

from ignite.contrib.metrics.regression._base import _BaseRegressionEpoch


-def median_absolute_percentage_error_compute_fn(y_pred, y):
+def median_absolute_percentage_error_compute_fn(y_pred: torch.Tensor, y: torch.Tensor):
e = torch.abs(y.view_as(y_pred) - y_pred) / torch.abs(y.view_as(y_pred))
return 100.0 * torch.median(e).item()

@@ -31,7 +33,7 @@ class MedianAbsolutePercentageError(_BaseRegressionEpoch):

"""

-def __init__(self, output_transform=lambda x: x):
+def __init__(self, output_transform: Callable = lambda x: x):
super(MedianAbsolutePercentageError, self).__init__(
median_absolute_percentage_error_compute_fn, output_transform
)
@@ -1,9 +1,11 @@
from typing import Callable

import torch

from ignite.contrib.metrics.regression._base import _BaseRegressionEpoch


-def median_relative_absolute_error_compute_fn(y_pred, y):
+def median_relative_absolute_error_compute_fn(y_pred: torch.Tensor, y: torch.Tensor):
e = torch.abs(y.view_as(y_pred) - y_pred) / torch.abs(y.view_as(y_pred) - torch.mean(y))
return torch.median(e).item()

@@ -31,5 +33,5 @@ class MedianRelativeAbsoluteError(_BaseRegressionEpoch):

"""

-def __init__(self, output_transform=lambda x: x):
+def __init__(self, output_transform: Callable = lambda x: x):
super(MedianRelativeAbsoluteError, self).__init__(median_relative_absolute_error_compute_fn, output_transform)
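
A closing sketch with the class from this last hunk, using hypothetical values (errors are |y - y_pred| / |y - mean(y)|; note torch.median returns the lower of the two middle values for an even count):

```python
import torch

from ignite.contrib.metrics.regression import MedianRelativeAbsoluteError

m = MedianRelativeAbsoluteError()
m.update((torch.tensor([2.0, 3.0]), torch.tensor([1.0, 5.0])))  # (y_pred, y)
# mean(y) = 3.0; errors: [1/2, 2/2] = [0.5, 1.0] -> median 0.5
print(m.compute())
```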