Skip to content

Commit

Permalink
enable preprocessing and remove is_small_preprocess
Browse files Browse the repository at this point in the history
  • Loading branch information
ravinkohli committed Jan 24, 2022
1 parent 1c1ff8a commit 23d5fd4
Show file tree
Hide file tree
Showing 22 changed files with 155 additions and 180 deletions.
116 changes: 59 additions & 57 deletions autoPyTorch/data/tabular_feature_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,67 +14,51 @@
from sklearn.exceptions import NotFittedError
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.preprocessing import OrdinalEncoder

from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SUPPORTED_FEAT_TYPES


# NOTE(review): this listing has lost its indentation and appears to show the
# removed and the added lines of the commit interleaved; confirm the final
# signature and body against the repository before relying on it.
def _create_column_transformer(
preprocessors: Dict[str, List[BaseEstimator]],
numerical_columns: List[str],
categorical_columns: List[str],
) -> ColumnTransformer:
"""
Given a dictionary of preprocessors, this function
creates a sklearn column transformer with appropriate
columns associated with their preprocessors.
Args:
preprocessors (Dict[str, List[BaseEstimator]]):
Dictionary containing list of numerical and categorical preprocessors.
numerical_columns (List[str]):
List of names of numerical columns
categorical_columns (List[str]):
List of names of categorical columns
Returns:
ColumnTransformer
"""

# Both pipelines start as the string 'drop' and are replaced by a real
# sklearn pipeline only when the matching column list is non-empty.
numerical_pipeline = 'drop'
categorical_pipeline = 'drop'
if len(numerical_columns) > 0:
numerical_pipeline = make_pipeline(*preprocessors['numerical'])
if len(categorical_columns) > 0:
categorical_pipeline = make_pipeline(*preprocessors['categorical'])
# Same assignment as the line above, repeated without the length guard
# (presumably the post-commit form, where the caller only builds the
# transformer when categorical columns exist — TODO confirm).
categorical_pipeline = make_pipeline(*preprocessors['categorical'])

# Two argument lists follow: one registers both pipelines with
# remainder='drop', the other registers only the categorical pipeline with
# remainder='passthrough'. In sklearn, 'drop' discards columns not assigned to
# a transformer, while 'passthrough' appends them untransformed.
# NOTE(review): these look like the before/after sides of the diff — confirm.
return ColumnTransformer([
('categorical_pipeline', categorical_pipeline, categorical_columns),
('numerical_pipeline', numerical_pipeline, numerical_columns)],
remainder='drop'
('categorical_pipeline', categorical_pipeline, categorical_columns)],
remainder='passthrough'
)


# NOTE(review): this listing has lost its indentation and appears to show the
# removed and the added lines of the commit interleaved; confirm the final
# body against the repository before relying on it.
def get_tabular_preprocessors() -> Dict[str, List[BaseEstimator]]:
"""
This function creates a Dictionary containing a list
of numerical and categorical preprocessors
Returns:
Dict[str, List[BaseEstimator]]
"""
preprocessors: Dict[str, List[BaseEstimator]] = dict()

# Categorical Preprocessors
# Both a one-hot and an ordinal encoder are constructed in this listing;
# each is used by exactly one of the two 'categorical' assignments below.
onehot_encoder = OneHotEncoder(categories='auto', sparse=False, handle_unknown='ignore')
# Categories not seen during fit are encoded as -1 at transform time.
ordinal_encoder = OrdinalEncoder(handle_unknown='use_encoded_value',
unknown_value=-1)
categorical_imputer = SimpleImputer(strategy='constant', copy=False)

# Numerical Preprocessors
numerical_imputer = SimpleImputer(strategy='median', copy=False)
standard_scaler = StandardScaler(with_mean=True, with_std=True, copy=False)

preprocessors['categorical'] = [categorical_imputer, onehot_encoder]
preprocessors['numerical'] = [numerical_imputer, standard_scaler]
# Read top to bottom, this second assignment replaces the one-hot list above,
# leaving [imputer, ordinal encoder] as the categorical preprocessors.
# NOTE(review): presumably the post-commit line of the diff — confirm.
preprocessors['categorical'] = [categorical_imputer, ordinal_encoder]

return preprocessors

Expand Down Expand Up @@ -161,31 +145,48 @@ def _fit(

X = cast(pd.DataFrame, X)

self.all_nan_columns = set([column for column in X.columns if X[column].isna().all()])

categorical_columns, numerical_columns, feat_type = self._get_columns_info(X)

self.enc_columns = categorical_columns
all_nan_columns = []
for column in X.columns:
if X[column].isna().all():
X[column] = pd.to_numeric(X[column])
# Also note this change in self.dtypes
if len(self.dtypes) != 0:
self.dtypes[list(X.columns).index(column)] = X[column].dtype.name
all_nan_columns.append(column)
self.all_nan_columns = set(all_nan_columns)

self.enc_columns, self.feat_type = self._get_columns_info(X)

if len(self.enc_columns) > 0:

preprocessors = get_tabular_preprocessors()
self.column_transformer = _create_column_transformer(
preprocessors=preprocessors,
categorical_columns=self.enc_columns,
)

preprocessors = get_tabular_preprocessors()
self.column_transformer = _create_column_transformer(
preprocessors=preprocessors,
numerical_columns=numerical_columns,
categorical_columns=categorical_columns,
)
# Mypy redefinition
assert self.column_transformer is not None
self.column_transformer.fit(X)

# Mypy redefinition
assert self.column_transformer is not None
self.column_transformer.fit(X)
# The column transformer reorders the feature types;
# therefore, we need to reorder the columns as well.
# This means categorical columns are shifted to the left.

# The column transformer reorders the feature types;
# therefore, we need to reorder the columns as well.
# This means categorical columns are shifted to the left.
self.feat_type = sorted(
self.feat_type,
key=functools.cmp_to_key(self._comparator)
)

self.feat_type = sorted(
feat_type,
key=functools.cmp_to_key(self._comparator)
)
encoded_categories = self.column_transformer.\
named_transformers_['categorical_pipeline'].\
named_steps['ordinalencoder'].categories_
self.categories = [
# We fit an ordinal encoder, where all categorical
# columns are shifted to the left
list(range(len(cat)))
for cat in encoded_categories
]

# Unlike categorical_columns and numerical_columns,
# this stores the index of each column rather than its name.
Expand Down Expand Up @@ -264,6 +265,17 @@ def transform(

if hasattr(X, "iloc") and not scipy.sparse.issparse(X):
X = cast(Type[pd.DataFrame], X)
if self.all_nan_columns is not None:
for column in X.columns:
if column in self.all_nan_columns:
if not X[column].isna().all():
X[column] = np.nan
X[column] = pd.to_numeric(X[column])
if len(self.categorical_columns) > 0:
categorical_columns = self.column_transformer.transformers_[0][-1]
for column in categorical_columns:
if X[column].isna().all():
X[column] = X[column].astype('object')

# Check the data here so we catch problems on new test data
self._check_data(X)
Expand All @@ -273,11 +285,6 @@ def transform(
# We need to convert the column in the test data to
# object; otherwise, the test column is interpreted as float
if self.column_transformer is not None:
if len(self.categorical_columns) > 0:
categorical_columns = self.column_transformer.transformers_[0][-1]
for column in categorical_columns:
if X[column].isna().all():
X[column] = X[column].astype('object')
X = self.column_transformer.transform(X)

# Sparse related transformations
Expand Down Expand Up @@ -362,18 +369,17 @@ def _check_data(

dtypes = [dtype.name for dtype in X.dtypes]

diff_cols = X.columns[[s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]]
if len(self.dtypes) == 0:
self.dtypes = dtypes
elif not self._is_datasets_consistent(diff_cols, X):
elif self.dtypes != dtypes:
raise ValueError("The dtype of the features must not be changed after fit(), but"
" the dtypes of some columns are different between training ({}) and"
" test ({}) datasets.".format(self.dtypes, dtypes))

def _get_columns_info(
self,
X: pd.DataFrame,
) -> Tuple[List[str], List[str], List[str]]:
) -> Tuple[List[str], List[str]]:
"""
Return the columns to be encoded from a pandas dataframe
Expand All @@ -392,15 +398,12 @@ def _get_columns_info(
"""

# Register if a column needs encoding
numerical_columns = []
categorical_columns = []
# Also, register the feature types for the estimator
feat_type = []

# Make sure each column is a valid type
for i, column in enumerate(X.columns):
if self.all_nan_columns is not None and column in self.all_nan_columns:
continue
column_dtype = self.dtypes[i]
err_msg = "Valid types are `numerical`, `categorical` or `boolean`, " \
"but input column {} has an invalid type `{}`.".format(column, column_dtype)
Expand All @@ -411,7 +414,6 @@ def _get_columns_info(
# TypeError: data type not understood in certain pandas types
elif is_numeric_dtype(column_dtype):
feat_type.append('numerical')
numerical_columns.append(column)
elif column_dtype == 'object':
# TODO: verify how this could happen when we always convert the object dtypes to category
raise TypeError(
Expand All @@ -437,7 +439,7 @@ def _get_columns_info(
"before feeding it to AutoPyTorch.".format(err_msg)
)

return categorical_columns, numerical_columns, feat_type
return categorical_columns, feat_type

def list_to_pandas(
self,
Expand Down
2 changes: 1 addition & 1 deletion autoPyTorch/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def __init__(

# TODO: Look for a criterion to define what is small enough to preprocess
# False for the regularization cocktails initially
self.is_small_preprocess = False
# self.is_small_preprocess = False

# Make sure cross validation splits are created once
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import numpy as np

from sklearn.compose import ColumnTransformer
# from sklearn.pipeline import make_pipeline
from sklearn.pipeline import make_pipeline

import torch

from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.base_tabular_preprocessing import (
autoPyTorchTabularPreprocessingComponent
)
# from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.utils import get_tabular_preprocessers
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.utils import get_tabular_preprocessers
from autoPyTorch.utils.common import FitRequirement, subsampler


Expand Down Expand Up @@ -52,11 +52,11 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "TabularColumnTransformer":
numerical_pipeline = 'passthrough'
categorical_pipeline = 'passthrough'

# preprocessors = get_tabular_preprocessers(X)
# if len(X['dataset_properties']['numerical_columns']):
# numerical_pipeline = make_pipeline(*preprocessors['numerical'])
# if len(X['dataset_properties']['categorical_columns']):
# categorical_pipeline = make_pipeline(*preprocessors['categorical'])
preprocessors = get_tabular_preprocessers(X)
if len(X['dataset_properties']['numerical_columns']):
numerical_pipeline = make_pipeline(*preprocessors['numerical'])
if len(X['dataset_properties']['categorical_columns']):
categorical_pipeline = make_pipeline(*preprocessors['categorical'])

self.preprocessor = ColumnTransformer([
('numerical_pipeline', numerical_pipeline, X['dataset_properties']['numerical_columns']),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
Returns:
(Dict[str, Any]): the updated 'X' dictionary
"""
# X.update({'encoder': self.preprocessor})
X.update({'encoder': self.preprocessor})
return X

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
if self.preprocessor['numerical'] is None and self.preprocessor['categorical'] is None:
raise ValueError("cant call transform on {} without fitting first."
.format(self.__class__.__name__))
# X.update({'encoder': self.preprocessor})
X.update({'encoder': self.preprocessor})
return X
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
if self.preprocessor['numerical'] is None and self.preprocessor['categorical'] is None:
raise ValueError("cant call transform on {} without fitting first."
.format(self.__class__.__name__))
# X.update({'imputer': self.preprocessor})
X.update({'imputer': self.preprocessor})
return X
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
Returns:
(Dict[str, Any]): the updated 'X' dictionary
"""
# X.update({'scaler': self.preprocessor})
X.update({'scaler': self.preprocessor})
return X

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
if self.preprocessor['numerical'] is None and self.preprocessor['categorical'] is None:
raise ValueError("cant call transform on {} without fitting first."
.format(self.__class__.__name__))
# X.update({'scaler': self.preprocessor})
X.update({'scaler': self.preprocessor})
return X
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, random_state: Optional[np.random.RandomState] = None) -> None
super().__init__()
self.random_state = random_state
self.add_fit_requirements([
FitRequirement('is_small_preprocess', (bool,), user_defined=True, dataset_property=True),
# FitRequirement('is_small_preprocess', (bool,), user_defined=True, dataset_property=True),
FitRequirement('X_train', (np.ndarray, pd.DataFrame, csr_matrix), user_defined=True,
dataset_property=False)])

Expand All @@ -32,14 +32,13 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "EarlyPreprocessing":
def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:

transforms = get_preprocess_transforms(X)
if X['dataset_properties']['is_small_preprocess']:
if 'X_train' in X:
X_train = X['X_train']
else:
# Incorporate the transform to the dataset
X_train = X['backend'].load_datamanager().train_tensors[0]

X['X_train'] = preprocess(dataset=X_train, transforms=transforms)
if 'X_train' in X:
X_train = X['X_train']
else:
# Incorporate the transform to the dataset
X_train = X['backend'].load_datamanager().train_tensors[0]

X['X_train'] = preprocess(dataset=X_train, transforms=transforms)

# We need to also save the preprocess transforms for inference
X.update({'preprocess_transforms': transforms})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self,
**kwargs: Any):
super().__init__()
self.add_fit_requirements([
FitRequirement('is_small_preprocess', (bool,), user_defined=True, dataset_property=True),
# FitRequirement('is_small_preprocess', (bool,), user_defined=True, dataset_property=True),
FitRequirement('X_train', (np.ndarray, pd.DataFrame, csr_matrix), user_defined=True,
dataset_property=False),
FitRequirement('input_shape', (Iterable,), user_defined=True, dataset_property=True),
Expand All @@ -52,12 +52,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
self.check_requirements(X, y)
X_train = X['X_train']

if X["dataset_properties"]["is_small_preprocess"]:
input_shape = X_train.shape[1:]
else:
# get input shape by transforming the first element of the training set
column_transformer = X['tabular_transformer'].preprocessor
input_shape = column_transformer.transform(X_train[:1]).shape[1:]
input_shape = X_train.shape[1:]

input_shape = get_output_shape(X['network_embedding'], input_shape=input_shape)
self.input_shape = input_shape
Expand Down
Loading

0 comments on commit 23d5fd4

Please sign in to comment.