Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Reformatted dataprofiler/labelers using black 22.3.0. #526

Merged
merged 2 commits into from
Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions dataprofiler/labelers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,14 @@
# import data labelers
# import models
from .base_data_labeler import BaseDataLabeler, TrainableDataLabeler
from .data_labelers import DataLabeler, StructuredDataLabeler, \
UnstructuredDataLabeler
from .data_labelers import DataLabeler, StructuredDataLabeler, UnstructuredDataLabeler

# import data processors
from .data_processing import CharPostprocessor, CharPreprocessor, \
DirectPassPreprocessor, RegexPostProcessor, StructCharPostprocessor, \
StructCharPreprocessor
from .data_processing import (
CharPostprocessor,
CharPreprocessor,
DirectPassPreprocessor,
RegexPostProcessor,
StructCharPostprocessor,
StructCharPreprocessor,
)
355 changes: 202 additions & 153 deletions dataprofiler/labelers/base_data_labeler.py

Large diffs are not rendered by default.

91 changes: 56 additions & 35 deletions dataprofiler/labelers/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ def __eq__(self, other):
:return: Whether or not self and other are equal
:rtype: bool
"""
if type(self) != type(other) \
or self._parameters != other._parameters \
or self._label_mapping != other._label_mapping:
if (
type(self) != type(other)
or self._parameters != other._parameters
or self._label_mapping != other._label_mapping
):
return False
return True

Expand Down Expand Up @@ -88,8 +90,12 @@ def labels(self):
Retrieves the label
:return: list of labels
"""
return [v for k, v in sorted(self.reverse_label_mapping.items(),
key=lambda item: item[0])]
return [
v
for k, v in sorted(
self.reverse_label_mapping.items(), key=lambda item: item[0]
)
]

@staticmethod
def _convert_labels_to_label_mapping(labels, requires_zero_mapping):
Expand All @@ -108,8 +114,7 @@ def _convert_labels_to_label_mapping(labels, requires_zero_mapping):

# if list
start_index = 0 if requires_zero_mapping else 1
return dict(zip(labels, list(
range(start_index, start_index + len(labels)))))
return dict(zip(labels, list(range(start_index, start_index + len(labels)))))

@property
def num_labels(self):
Expand All @@ -118,10 +123,10 @@ def num_labels(self):
@classmethod
def get_class(cls, class_name):

# Import possible internal models
# Import possible internal models
from .character_level_cnn_model import CharacterLevelCnnModel
from .regex_model import RegexModel

return cls._BaseModel__subclasses.get(class_name.lower(), None)

def get_parameters(self, param_list=None):
Expand All @@ -133,18 +138,21 @@ def get_parameters(self, param_list=None):
"""
if param_list is None:
parameters = copy.deepcopy(self._parameters)
parameters['label_mapping'] = copy.deepcopy(self._label_mapping)
parameters["label_mapping"] = copy.deepcopy(self._label_mapping)
return parameters

param_dict = {}
for param in param_list:
if param in self._parameters:
param_dict[param] = self._parameters.get(param)
elif param == 'label_mapping':
param_dict['label_mapping'] = self._label_mapping
elif param == "label_mapping":
param_dict["label_mapping"] = self._label_mapping
else:
raise ValueError('`{}` does not exist as a parameter in {}.'.
format(param, self.__class__.__name__))
raise ValueError(
"`{}` does not exist as a parameter in {}.".format(
param, self.__class__.__name__
)
)
return copy.deepcopy(param_dict)

def set_params(self, **kwargs):
Expand All @@ -169,23 +177,25 @@ def add_label(self, label, same_as=None):
"""
# validate label
if not label or not isinstance(label, str):
raise TypeError('`label` must be a str.')
raise TypeError("`label` must be a str.")
elif label in self._label_mapping:
warnings.warn('The label, `{}`, already exists in the label '
'mapping.'.format(label))
warnings.warn(
"The label, `{}`, already exists in the label " "mapping.".format(label)
)
return

# validate same_as
if same_as and not isinstance(same_as, str):
raise TypeError('`same_as` must be a str.')
raise TypeError("`same_as` must be a str.")
elif same_as and same_as not in self._label_mapping:
raise ValueError('`same_as` value: {}, did not exist in the '
'label_mapping.'.format(same_as))
raise ValueError(
"`same_as` value: {}, did not exist in the "
"label_mapping.".format(same_as)
)

# add label to label_mapping
max_label_ind = max(self._label_mapping.values())
self._label_mapping[label] = self._label_mapping.get(same_as,
max_label_ind + 1)
self._label_mapping[label] = self._label_mapping.get(same_as, max_label_ind + 1)

def set_label_mapping(self, label_mapping):
"""
Expand All @@ -197,10 +207,13 @@ def set_label_mapping(self, label_mapping):
:return: None
"""
if not isinstance(label_mapping, (list, dict)) or not label_mapping:
raise TypeError("Labels must either be a non-empty encoding dict "
"which maps labels to index encodings or a list.")
raise TypeError(
"Labels must either be a non-empty encoding dict "
"which maps labels to index encodings or a list."
)
label_mapping = self._convert_labels_to_label_mapping(
label_mapping, self.requires_zero_mapping)
label_mapping, self.requires_zero_mapping
)
self._label_mapping = copy.deepcopy(label_mapping)

@abc.abstractmethod
Expand Down Expand Up @@ -230,12 +243,15 @@ def help(cls):
:return: None
"""
param_docs = inspect.getdoc(cls._validate_parameters)
param_start_ind = param_docs.find('parameters:\n') + 12
param_end_ind = param_docs.find(':type parameters:') - 1

help_str = cls.__name__ + "\n\n" + \
"Parameters:\n" + \
param_docs[param_start_ind:param_end_ind]
param_start_ind = param_docs.find("parameters:\n") + 12
param_end_ind = param_docs.find(":type parameters:") - 1

help_str = (
cls.__name__
+ "\n\n"
+ "Parameters:\n"
+ param_docs[param_start_ind:param_end_ind]
)
print(help_str)

@abc.abstractmethod
Expand Down Expand Up @@ -307,10 +323,16 @@ def save_to_disk(self, dirpath):


class BaseTrainableModel(BaseModel, metaclass=abc.ABCMeta):

@abc.abstractmethod
def fit(self, train_data, val_data, batch_size=32, epochs=1,
label_mapping=None, reset_weights=False):
def fit(
self,
train_data,
val_data,
batch_size=32,
epochs=1,
label_mapping=None,
reset_weights=False,
):
"""
Train the current model with the training data and validation data
:param train_data: Training data used to train model
Expand All @@ -329,4 +351,3 @@ def fit(self, train_data, val_data, batch_size=32, epochs=1,
:return: None
"""
raise NotImplementedError()

Loading