Skip to content

Commit

Permalink
Add static typing to labelers/data_processing.py (#673)
Browse files Browse the repository at this point in the history
* Fix typing

* Fix tests

* Add static typing

* More static typing

* Add static typing

* Fix issue

* Update dataprofiler/tests/labelers/test_data_processing.py

Co-authored-by: Taylor Turner <[email protected]>

* Fix typing

* Improve typing with generics

* Fix typing

* Fix errors

* Add base process again

Co-authored-by: Taylor Turner <[email protected]>
  • Loading branch information
tonywu315 and taylorfturner authored Oct 7, 2022
1 parent f30fc75 commit 2e35ced
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 163 deletions.
31 changes: 18 additions & 13 deletions dataprofiler/labelers/base_data_labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import sys
import warnings
from typing import Dict, List, Optional, Union, cast
from typing import Dict, List, Optional, Type, Union, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -135,7 +135,7 @@ def preprocessor(self) -> Optional[data_processing.BaseDataPreprocessor]:
return self._preprocessor

@property
def model(self) -> Optional[BaseModel]:
def model(self) -> BaseModel:
"""
Retrieve the data labeler model.
Expand Down Expand Up @@ -511,32 +511,35 @@ def _load_parameters(dirpath: str, load_options: Dict = None) -> Dict[str, Dict]
params["postprocessor"]["class"] = load_options.get("postprocessor_class")
return params

def _load_model(self, model_class: Union[BaseModel, str], dirpath: str) -> None:
def _load_model(
self, model_class: Optional[Union[Type[BaseModel], str]], dirpath: str
) -> None:
"""
Load the data labeler model.
Can be done by either using a provided model class or
retrieving a registered data labeler model.
:param model_class: class of model being loaded
:type model_class: Union[BaseModel, str]
:type model_class: Optional[Union[Type[BaseModel], str]]
:param dirpath: directory where the saved DataLabeler model exists.
:type dirpath: str
:return: None
"""
if isinstance(model_class, str):
model_class: BaseModel = BaseModel.get_class(model_class) # type: ignore
model_class = BaseModel.get_class(model_class)

if not model_class:
raise ValueError(
"`model_class`, {}, was not set in load_options "
"and could not be found as a registered model "
"class in BaseModel.".format(str(model_class))
)
self.set_model(cast(BaseModel, model_class).load_from_disk(dirpath))
self.set_model(model_class.load_from_disk(dirpath))

def _load_preprocessor(
self,
processor_class: Union[data_processing.BaseDataProcessor, str],
processor_class: Optional[Union[Type[data_processing.BaseDataProcessor], str]],
dirpath: str,
) -> None:
"""
Expand All @@ -559,21 +562,22 @@ def _load_preprocessor(
"class in BaseDataProcessor.".format(str(processor_class))
)
self.set_preprocessor(
cast(data_processing.BaseDataProcessor, processor_class).load_from_disk(
dirpath
cast(
data_processing.BaseDataPreprocessor,
processor_class.load_from_disk(dirpath),
)
)

def _load_postprocessor(
self,
processor_class: Union[data_processing.BaseDataProcessor, str],
processor_class: Optional[Union[Type[data_processing.BaseDataProcessor], str]],
dirpath: str,
) -> None:
"""
Load the postprocessor for the data labeler.
:param processor_class: class of postprocessor being loaded
:type processor_class: Union[data_processing.BaseDataProcessor, str]
:type processor_class: Optional[Union[Type[data_processing.BaseDataProcessor], str]]
:param dirpath: directory where the saved DataLabeler model exists.
:type dirpath: str
:return: None
Expand All @@ -591,8 +595,9 @@ def _load_postprocessor(
)
)
self.set_postprocessor(
cast(data_processing.BaseDataProcessor, processor_class).load_from_disk(
dirpath
cast(
data_processing.BaseDataPostprocessor,
processor_class.load_from_disk(dirpath),
)
)

Expand Down
7 changes: 4 additions & 3 deletions dataprofiler/labelers/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import copy
import inspect
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast

from dataprofiler._typing import DataArray

Expand Down Expand Up @@ -59,7 +59,7 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
cls._register_subclass()

def __eq__(self, other: BaseModel) -> bool: # type: ignore
def __eq__(self, other: object) -> bool:
"""
Check if two models are equal with one another.
Expand All @@ -74,6 +74,7 @@ def __eq__(self, other: BaseModel) -> bool: # type: ignore
"""
if (
type(self) != type(other)
or not isinstance(other, BaseModel)
or self._parameters != other._parameters
or self._label_mapping != other._label_mapping
):
Expand Down Expand Up @@ -270,7 +271,7 @@ def help(cls) -> None:
:return: None
"""
param_docs: str = inspect.getdoc(cls._validate_parameters) # type: ignore
param_docs = cast(str, inspect.getdoc(cls._validate_parameters))
param_start_ind = param_docs.find("parameters:\n") + 12
param_end_ind = param_docs.find(":type parameters:") - 1

Expand Down
Loading

0 comments on commit 2e35ced

Please sign in to comment.