Skip to content

Commit

Permalink
[FIX] Enable preprocessing in reg_cocktails (#369)
Browse files Browse the repository at this point in the history
* enable preprocessing and remove is_small_preprocess

* address comments from shuhei and fix precommit checks

* fix tests

* fix precommit checks

* add suggestions from shuhei for astype use

* address speed issue when using object_dtype_mapping

* make code more readable

* improve documentation for base network embedding
  • Loading branch information
ravinkohli committed Feb 28, 2022
1 parent abd1588 commit 6cbe591
Show file tree
Hide file tree
Showing 34 changed files with 172 additions and 793 deletions.
1 change: 0 additions & 1 deletion autoPyTorch/api/tabular_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.datasets.resampling_strategy import (
HoldoutValTypes,
CrossValTypes,
ResamplingStrategies,
)
from autoPyTorch.datasets.tabular_dataset import TabularDataset
Expand Down
1 change: 0 additions & 1 deletion autoPyTorch/api/tabular_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.datasets.resampling_strategy import (
HoldoutValTypes,
CrossValTypes,
ResamplingStrategies,
)
from autoPyTorch.datasets.tabular_dataset import TabularDataset
Expand Down
147 changes: 80 additions & 67 deletions autoPyTorch/data/tabular_feature_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sklearn.exceptions import NotFittedError
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.preprocessing import OrdinalEncoder

from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SupportedFeatTypes
from autoPyTorch.data.utils import (
Expand All @@ -28,7 +28,6 @@

def _create_column_transformer(
preprocessors: Dict[str, List[BaseEstimator]],
numerical_columns: List[str],
categorical_columns: List[str],
) -> ColumnTransformer:
"""
Expand All @@ -39,49 +38,36 @@ def _create_column_transformer(
Args:
preprocessors (Dict[str, List[BaseEstimator]]):
Dictionary containing list of numerical and categorical preprocessors.
numerical_columns (List[str]):
List of names of numerical columns
categorical_columns (List[str]):
List of names of categorical columns
Returns:
ColumnTransformer
"""

numerical_pipeline = 'drop'
categorical_pipeline = 'drop'
if len(numerical_columns) > 0:
numerical_pipeline = make_pipeline(*preprocessors['numerical'])
if len(categorical_columns) > 0:
categorical_pipeline = make_pipeline(*preprocessors['categorical'])
categorical_pipeline = make_pipeline(*preprocessors['categorical'])

return ColumnTransformer([
('categorical_pipeline', categorical_pipeline, categorical_columns),
('numerical_pipeline', numerical_pipeline, numerical_columns)],
remainder='drop'
('categorical_pipeline', categorical_pipeline, categorical_columns)],
remainder='passthrough'
)


def get_tabular_preprocessors() -> Dict[str, List[BaseEstimator]]:
"""
This function creates a Dictionary containing a list
of numerical and categorical preprocessors
Returns:
Dict[str, List[BaseEstimator]]
"""
preprocessors: Dict[str, List[BaseEstimator]] = dict()

# Categorical Preprocessors
onehot_encoder = OneHotEncoder(categories='auto', sparse=False, handle_unknown='ignore')
ordinal_encoder = OrdinalEncoder(handle_unknown='use_encoded_value',
unknown_value=-1)
categorical_imputer = SimpleImputer(strategy='constant', copy=False)

# Numerical Preprocessors
numerical_imputer = SimpleImputer(strategy='median', copy=False)
standard_scaler = StandardScaler(with_mean=True, with_std=True, copy=False)

preprocessors['categorical'] = [categorical_imputer, onehot_encoder]
preprocessors['numerical'] = [numerical_imputer, standard_scaler]
preprocessors['categorical'] = [categorical_imputer, ordinal_encoder]

return preprocessors

Expand Down Expand Up @@ -176,31 +162,47 @@ def _fit(
if hasattr(X, "iloc") and not issparse(X):
X = cast(pd.DataFrame, X)

self.all_nan_columns = set([column for column in X.columns if X[column].isna().all()])
all_nan_columns = X.columns[X.isna().all()]
for col in all_nan_columns:
X[col] = pd.to_numeric(X[col])

# Handle objects if possible
exist_object_columns = has_object_columns(X.dtypes.values)
if exist_object_columns:
X = self.infer_objects(X)

categorical_columns, numerical_columns, feat_type = self._get_columns_info(X)
self.dtypes = [dt.name for dt in X.dtypes] # Also note this change in self.dtypes
self.all_nan_columns = set(all_nan_columns)

self.enc_columns = categorical_columns
self.enc_columns, self.feat_type = self._get_columns_info(X)

preprocessors = get_tabular_preprocessors()
self.column_transformer = _create_column_transformer(
preprocessors=preprocessors,
numerical_columns=numerical_columns,
categorical_columns=categorical_columns,
)
if len(self.enc_columns) > 0:

# Mypy redefinition
assert self.column_transformer is not None
self.column_transformer.fit(X)
preprocessors = get_tabular_preprocessors()
self.column_transformer = _create_column_transformer(
preprocessors=preprocessors,
categorical_columns=self.enc_columns,
)

# The column transformer reorders the feature types
# therefore, we need to change the order of columns as well
# This means categorical columns are shifted to the left
# Mypy redefinition
assert self.column_transformer is not None
self.column_transformer.fit(X)

self.feat_type = sorted(
feat_type,
key=functools.cmp_to_key(self._comparator)
)
# The column transformer moves categorical columns before all numerical columns
# therefore, we need to sort categorical columns so that it complies this change

self.feat_type = sorted(
self.feat_type,
key=functools.cmp_to_key(self._comparator)
)

encoded_categories = self.column_transformer.\
named_transformers_['categorical_pipeline'].\
named_steps['ordinalencoder'].categories_
self.categories = [
list(range(len(cat)))
for cat in encoded_categories
]

# differently to categorical_columns and numerical_columns,
# this saves the index of the column.
Expand Down Expand Up @@ -280,6 +282,23 @@ def transform(
if hasattr(X, "iloc") and not scipy.sparse.issparse(X):
X = cast(Type[pd.DataFrame], X)

if self.all_nan_columns is None:
raise ValueError('_fit must be called before calling transform')

for col in list(self.all_nan_columns):
X[col] = np.nan
X[col] = pd.to_numeric(X[col])

if len(self.categorical_columns) > 0:
# when some categorical columns are not all nan in the training set
# but they are all nan in the testing or validation set
# we change those columns to `object` dtype
# to ensure that these columns are changed to appropriate dtype
# in self.infer_objects
all_nan_cat_cols = set(X[self.enc_columns].columns[X[self.enc_columns].isna().all()])
dtype_dict = {col: 'object' for col in self.enc_columns if col in all_nan_cat_cols}
X = X.astype(dtype_dict)

# Check the data here so we catch problems on new test data
self._check_data(X)

Expand All @@ -288,11 +307,6 @@ def transform(
# We need to convert the column in test data to
# object otherwise the test column is interpreted as float
if self.column_transformer is not None:
if len(self.categorical_columns) > 0:
categorical_columns = self.column_transformer.transformers_[0][-1]
for column in categorical_columns:
if X[column].isna().all():
X[column] = X[column].astype('object')
X = self.column_transformer.transform(X)

# Sparse related transformations
Expand Down Expand Up @@ -407,7 +421,6 @@ def _check_data(
self.column_order = column_order

dtypes = [dtype.name for dtype in X.dtypes]

diff_cols = X.columns[[s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]]
if len(self.dtypes) == 0:
self.dtypes = dtypes
Expand All @@ -419,7 +432,7 @@ def _check_data(
def _get_columns_info(
self,
X: pd.DataFrame,
) -> Tuple[List[str], List[str], List[str]]:
) -> Tuple[List[str], List[str]]:
"""
Return the columns to be encoded from a pandas dataframe
Expand All @@ -438,15 +451,12 @@ def _get_columns_info(
"""

# Register if a column needs encoding
numerical_columns = []
categorical_columns = []
# Also, register the feature types for the estimator
feat_type = []

# Make sure each column is a valid type
for i, column in enumerate(X.columns):
if self.all_nan_columns is not None and column in self.all_nan_columns:
continue
column_dtype = self.dtypes[i]
err_msg = "Valid types are `numerical`, `categorical` or `boolean`, " \
"but input column {} has an invalid type `{}`.".format(column, column_dtype)
Expand All @@ -457,7 +467,6 @@ def _get_columns_info(
# TypeError: data type not understood in certain pandas types
elif is_numeric_dtype(column_dtype):
feat_type.append('numerical')
numerical_columns.append(column)
elif column_dtype == 'object':
# TODO verify how would this happen when we always convert the object dtypes to category
raise TypeError(
Expand All @@ -483,7 +492,7 @@ def _get_columns_info(
"before feeding it to AutoPyTorch.".format(err_msg)
)

return categorical_columns, numerical_columns, feat_type
return categorical_columns, feat_type

def list_to_pandas(
self,
Expand Down Expand Up @@ -553,22 +562,26 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
pd.DataFrame
"""
if hasattr(self, 'object_dtype_mapping'):
# Mypy does not process the has attr. This dict is defined below
for key, dtype in self.object_dtype_mapping.items(): # type: ignore[has-type]
# honor the training data types
try:
X[key] = X[key].astype(dtype.name)
except Exception as e:
# Try inference if possible
self.logger.warning(f'Casting the column {key} to {dtype} caused the exception {e}')
pass
# honor the training data types
try:
# Mypy does not process the has attr.
X = X.astype(self.object_dtype_mapping) # type: ignore[has-type]
except Exception as e:
# Try inference if possible
self.logger.warning(f'Casting the columns to training dtypes ' # type: ignore[has-type]
f'{self.object_dtype_mapping} caused the exception {e}')
pass
else:
# Calling for the first time to infer the categories
X = X.infer_objects()
for column, data_type in zip(X.columns, X.dtypes):
if not is_numeric_dtype(data_type):
X[column] = X[column].astype('category')

if len(self.dtypes) != 0:
# when train data has no object dtype, but test does
# we prioritise the datatype given in training data
dtype_dict = {col: dtype for col, dtype in zip(X.columns, self.dtypes)}
X = X.astype(dtype_dict)
else:
# Calling for the first time to infer the categories
X = X.infer_objects()
dtype_dict = {col: 'category' for col, dtype in zip(X.columns, X.dtypes) if not is_numeric_dtype(dtype)}
X = X.astype(dtype_dict)
# only numerical attributes and categories
self.object_dtype_mapping = {column: data_type for column, data_type in zip(X.columns, X.dtypes)}

Expand Down
5 changes: 0 additions & 5 deletions autoPyTorch/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ def __init__(
self.holdout_validators: Dict[str, HoldOutFunc] = {}
self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
self.random_state = np.random.RandomState(seed=seed)
self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
self.shuffle = shuffle
self.resampling_strategy = resampling_strategy
self.resampling_strategy_args = resampling_strategy_args
Expand All @@ -165,10 +164,6 @@ def __init__(
if len(self.train_tensors) == 2 and self.train_tensors[1] is not None:
self.output_shape, self.output_type = _get_output_properties(self.train_tensors)

# TODO: Look for a criteria to define small enough to preprocess
# False for the regularization cocktails initially
self.is_small_preprocess = False

# Make sure cross validation splits are created once
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
Expand Down
7 changes: 0 additions & 7 deletions autoPyTorch/datasets/resampling_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,6 @@ def __call__(self, random_state: np.random.RandomState, val_share: float,
...


class NoResamplingFunc(Protocol):
def __call__(self,
random_state: np.random.RandomState,
indices: np.ndarray) -> np.ndarray:
...


class CrossValTypes(IntEnum):
"""The type of cross validation
Expand Down
Loading

0 comments on commit 6cbe591

Please sign in to comment.