Add static typing to labeler utils #668

Merged: 5 commits, merged Sep 29, 2022

41 changes: 21 additions & 20 deletions dataprofiler/labelers/classification_report_utils.py
@@ -1,11 +1,12 @@
"""Contains functions for classification."""
import warnings
from typing import Dict, List, Optional, Set, Tuple, Union, cast

import numpy as np
import sklearn.metrics._classification


def convert_confusion_matrix_to_MCM(conf_matrix):
def convert_confusion_matrix_to_MCM(conf_matrix: Union[List, np.ndarray]):
"""
Convert a confusion matrix into the MCM format.

@@ -52,14 +53,14 @@ def convert_confusion_matrix_to_MCM(conf_matrix):


def precision_recall_fscore_support(
MCM,
beta=1.0,
labels=None,
pos_label=1,
average=None,
warn_for=("precision", "recall", "f-score"),
sample_weight=None,
):
MCM: np.ndarray,
beta: float = 1.0,
labels: np.ndarray = None,
pos_label: Union[str, int] = 1,
average: str = None,
warn_for: Union[Tuple[str, ...], Set[str]] = ("precision", "recall", "f-score"),
sample_weight: np.ndarray = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
"""
Perform the same functionality as the precision_recall_fscore_support function.

@@ -222,13 +223,13 @@ def precision_recall_fscore_support(


def classification_report(
conf_matrix,
labels=None,
target_names=None,
sample_weight=None,
digits=2,
output_dict=False,
):
conf_matrix: np.ndarray,
labels: Union[List, np.ndarray] = None,
target_names: List[str] = None,
sample_weight: np.ndarray = None,
digits: int = 2,
output_dict: bool = False,
) -> Union[str, Dict]:
"""
Build a text report showing the main classification metrics.

@@ -342,15 +343,15 @@ def classification_report(
p, r, f1, s = precision_recall_fscore_support(
MCM, labels=labels, average=None, sample_weight=sample_weight
)
rows = zip(target_names, p, r, f1, s)
rows = zip(target_names, p, r, f1, cast(np.ndarray, s))

if y_type.startswith("multilabel"):
average_options = ("micro", "macro", "weighted", "samples")
average_options: Tuple[str, ...] = ("micro", "macro", "weighted", "samples")
Contributor:

Why the `...`? Do we not want to enforce a length of four, i.e. something like `Tuple[str, str, str, str]`?

Contributor Author:

The next line, in the else branch, is a tuple with length 3, so the annotation has to allow a variable length.

Contributor:

To specify a variable-length tuple of homogeneous type, use literal ellipsis, e.g. Tuple[int, ...]. A plain [Tuple](https://docs.python.org/3/library/typing.html#typing.Tuple) is equivalent to Tuple[Any, ...], and in turn to [tuple](https://docs.python.org/3/library/stdtypes.html#tuple).
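
As an editorial aside (a minimal sketch, not part of the PR), the variable-length annotation is what lets both branches below type-check: the if branch assigns a 4-tuple and the else branch a 3-tuple to the same variable, which a fixed `Tuple[str, str, str, str]` would reject. The helper name `pick_average_options` is hypothetical.

```python
from typing import Tuple

def pick_average_options(multilabel: bool) -> Tuple[str, ...]:
    # Mirrors the two branches in classification_report: one branch yields
    # four averaging options, the other only three.
    if multilabel:
        average_options: Tuple[str, ...] = ("micro", "macro", "weighted", "samples")
    else:
        # A fixed-length Tuple[str, str, str, str] annotation would make this
        # assignment a type error; Tuple[str, ...] accepts any length.
        average_options = ("micro", "macro", "weighted")
    return average_options

print(pick_average_options(True))   # ('micro', 'macro', 'weighted', 'samples')
print(pick_average_options(False))  # ('micro', 'macro', 'weighted')
```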

else:
average_options = ("micro", "macro", "weighted")

if output_dict:
report_dict = {label[0]: label[1:] for label in rows}
report_dict: Dict = {label[0]: label[1:] for label in rows}
for label, scores in report_dict.items():
report_dict[label] = dict(zip(headers, [i.item() for i in scores]))
else:
@@ -379,7 +380,7 @@ def classification_report(
avg_p, avg_r, avg_f1, _ = precision_recall_fscore_support(
MCM, labels=labels, average=average, sample_weight=sample_weight
)
avg = [avg_p, avg_r, avg_f1, np.sum(s)]
avg = [avg_p, avg_r, avg_f1, np.sum(cast(np.ndarray, s))]

if output_dict:
report_dict[line_heading] = dict(zip(headers, [i.item() for i in avg]))
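A brief editorial illustration (hypothetical stub, not code from the PR): because classification_report is now annotated to return `Union[str, Dict]`, call sites that pass `output_dict=True` still see the union from mypy's point of view, which is why the labeler_utils change below wraps the call in `cast(Dict, ...)`.

```python
from typing import Dict, Union, cast

def classification_report_stub(output_dict: bool = False) -> Union[str, Dict]:
    # Stand-in with the same return annotation as the PR's classification_report.
    if output_dict:
        return {"label": {"precision": 1.0, "recall": 1.0, "f1-score": 1.0}}
    return "label    1.00    1.00    1.00"

# At runtime this is always a dict, but mypy only sees Union[str, Dict],
# so the result is narrowed with cast before dict-style access.
report: Dict = cast(Dict, classification_report_stub(output_dict=True))
print(report["label"]["f1-score"])
```
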
100 changes: 62 additions & 38 deletions dataprofiler/labelers/labeler_utils.py
@@ -2,6 +2,7 @@
import logging
import os
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, cast

import numpy as np
import scipy
@@ -16,7 +17,7 @@
logger = dp_logging.get_child_logger(__name__)


def f1_report_dict_to_str(f1_report, label_names):
def f1_report_dict_to_str(f1_report: Dict, label_names: List[str]) -> str:
"""
Return the report string from the f1_report dict.

@@ -75,14 +76,14 @@ class 1 1.00 0.67 0.80 3


def evaluate_accuracy(
predicted_entities_in_index,
true_entities_in_index,
num_labels,
entity_rev_dict,
verbose=True,
omitted_labels=("PAD", "UNKNOWN"),
confusion_matrix_file=None,
):
predicted_entities_in_index: List[List[int]],
true_entities_in_index: List[List[int]],
num_labels: int,
entity_rev_dict: Dict,
verbose: bool = True,
omitted_labels: Tuple[str, ...] = ("PAD", "UNKNOWN"),
confusion_matrix_file: str = None,
) -> Tuple[float, Dict]:
"""
Evaluate accuracy from comparing predicted labels with true labels.

@@ -121,7 +122,7 @@ def evaluate_accuracy(
for i, true_labels_row in enumerate(true_entities_in_index):
true_labels_padded[i][: len(true_labels_row)] = true_labels_row

true_labels_flatten = np.hstack(true_labels_padded)
true_labels_flatten = np.hstack(true_labels_padded) # type: ignore
predicted_labels_flatten = np.hstack(predicted_entities_in_index)

if entity_rev_dict:
@@ -164,8 +165,11 @@ def evaluate_accuracy(

conf_mat_pd.to_csv(confusion_matrix_file)

f1_report = classification_report(
conf_mat, labels=label_indexes, target_names=label_names, output_dict=True
f1_report: Dict = cast(
Dict,
classification_report(
conf_mat, labels=label_indexes, target_names=label_names, output_dict=True
),
)

# adjust macro average to be updated only on positive support labels
@@ -183,25 +187,33 @@
if not num_labels_with_positive_support:
f1_report["macro avg"][metric] = np.nan
else:
f1_report["macro avg"][metric] *= (
float(len(label_names)) / num_labels_with_positive_support
)
if not label_names:
f1_report["macro avg"][metric] = 0
else:
f1_report["macro avg"][metric] *= (
float(len(label_names)) / num_labels_with_positive_support
)

if "macro avg" in f1_report:
f1 = f1_report["macro avg"]["f1-score"] # this is micro for the report
f1: float = f1_report["macro avg"]["f1-score"] # this is micro for the report
else:
# this is the only remaining option for the report
f1 = f1_report["accuracy"]

if verbose:
if not label_names:
label_names = [""]

f1_report_str = f1_report_dict_to_str(f1_report, label_names)
logger.info(f"(After removing non-entity tokens)\n{f1_report_str}")
logger.info(f"F1 Score: {f1}")

return f1, f1_report


def get_tf_layer_index_from_name(model, layer_name):
def get_tf_layer_index_from_name(
model: tf.keras.Model, layer_name: str
) -> Optional[int]:
"""
Return the index of the layer given the layer name within a tf model.

@@ -212,15 +224,16 @@ def get_tf_layer_index_from_name(model, layer_name):
for idx, layer in enumerate(model.layers):
if layer.name == layer_name:
return idx
return None


def hide_tf_logger_warnings():
def hide_tf_logger_warnings() -> None:
"""Filter out a set of warnings from the tf logger."""

class NoV1ResourceMessageFilter(logging.Filter):
"""Removes TF2 warning for using TF1 model which has resources."""

def filter(self, record):
def filter(self, record: logging.LogRecord) -> bool:
"""Remove warning."""
msg = (
"is a problem, consider rebuilding the SavedModel after "
@@ -232,15 +245,17 @@ def filter(self, record):
tf_logger.addFilter(NoV1ResourceMessageFilter())


def protected_register_keras_serializable(package="Custom", name=None):
def protected_register_keras_serializable(
package: str = "Custom", name: str = None
) -> Callable:
"""
Protect against already registered keras serializable layers.

Ensures that if it was already registered, it will not try to
register it again.
"""

def decorator(arg):
def decorator(arg: Any) -> Any:
"""Protect against double registration of a keras layer."""
class_name = name if name is not None else arg.__name__
registered_name = package + ">" + class_name
@@ -300,14 +315,14 @@ class FBetaScore(tf.keras.metrics.Metric):
# Modification: remove the run-time type checking for functions
def __init__(
self,
num_classes,
average=None,
beta=1.0,
threshold=None,
name="fbeta_score",
dtype=None,
**kwargs,
):
num_classes: int,
average: str = None,
beta: float = 1.0,
threshold: float = None,
name: str = "fbeta_score",
dtype: str = None,
**kwargs: Any,
) -> None:
"""Initialize FBetaScore class."""
super().__init__(name=name, dtype=dtype)

@@ -340,7 +355,7 @@ def __init__(
self.axis = 0
self.init_shape = [self.num_classes]

def _zero_wt_init(name):
def _zero_wt_init(name: str) -> tf.Variable:
return self.add_weight(
name, shape=self.init_shape, initializer="zeros", dtype=self.dtype
)
@@ -350,7 +365,9 @@ def _zero_wt_init(name):
self.false_negatives = _zero_wt_init("false_negatives")
self.weights_intermediate = _zero_wt_init("weights_intermediate")

def update_state(self, y_true, y_pred, sample_weight=None):
def update_state(
self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor = None
) -> None:
"""Update state."""
if self.threshold is None:
threshold = tf.reduce_max(y_pred, axis=-1, keepdims=True)
@@ -363,7 +380,9 @@ def update_state(self, y_true, y_pred, sample_weight=None):
y_true = tf.cast(y_true, self.dtype)
y_pred = tf.cast(y_pred, self.dtype)

def _weighted_sum(val, sample_weight):
def _weighted_sum(
val: tf.Tensor, sample_weight: Optional[tf.Tensor]
) -> tf.Tensor:
if sample_weight is not None:
val = tf.math.multiply(val, tf.expand_dims(sample_weight, 1))
return tf.reduce_sum(val, axis=self.axis)
@@ -377,7 +396,7 @@ def _weighted_sum(val, sample_weight):
)
self.weights_intermediate.assign_add(_weighted_sum(y_true, sample_weight))

def result(self):
def result(self) -> tf.Tensor:
"""Return f1 score."""
precision = tf.math.divide_no_nan(
self.true_positives, self.true_positives + self.false_positives
@@ -402,7 +421,7 @@ def result(self):

return f1_score

def get_config(self):
def get_config(self) -> Dict:
"""Return the serializable config of the metric."""
config = {
"num_classes": self.num_classes,
@@ -414,7 +433,7 @@ def get_config(self):
base_config = super().get_config()
return {**base_config, **config}

def reset_states(self):
def reset_states(self) -> None:
"""Reset states."""
reset_value = tf.zeros(self.init_shape, dtype=self.dtype)
tf.keras.backend.batch_set_value([(v, reset_value) for v in self.variables])
@@ -463,12 +482,17 @@ class F1Score(FBetaScore):

# Modification: remove the run-time type checking for functions
def __init__(
self, num_classes, average=None, threshold=None, name="f1_score", dtype=None
):
self,
num_classes: int,
average: str = None,
threshold: float = None,
name: str = "f1_score",
dtype: str = None,
) -> None:
"""Initialize F1Score object."""
super().__init__(num_classes, average, 1.0, threshold, name=name, dtype=dtype)

def get_config(self):
def get_config(self) -> Dict:
"""Get configuration."""
base_config = super().get_config()
del base_config["beta"]
9 changes: 5 additions & 4 deletions dataprofiler/labelers/utils.py
@@ -1,9 +1,10 @@
"""Contains functions for checking for installations/dependencies."""
import sys
import warnings
from typing import Any, Callable, List


def warn_missing_module(labeler_function, module_name):
def warn_missing_module(labeler_function: str, module_name: str) -> None:
"""
Return a warning if a given graph module doesn't exist.

@@ -21,7 +22,7 @@ def warn_missing_module(labeler_function, module_name):
warnings.warn(warning_msg, RuntimeWarning, stacklevel=3)


def require_module(names):
def require_module(names: List[str]) -> Callable:
"""
Check if a set of modules exists in sys.modules prior to running function.

@@ -32,8 +33,8 @@ def require_module(names):
:type names: list[str]
"""

def check_module(f):
def new_f(*args, **kwds):
def check_module(f) -> Callable:
def new_f(*args: Any, **kwds: Any) -> Any:
for module_name in names:
if module_name not in sys.modules.keys():
# attempt to reload if missing