From 23d5fd41b57f75b00881cd95a8a74acaffe6a682 Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Mon, 24 Jan 2022 18:52:48 +0100 Subject: [PATCH] enable preprocessing and remove is_small_preprocess --- autoPyTorch/data/tabular_feature_validator.py | 116 +++++++++--------- autoPyTorch/datasets/base_dataset.py | 2 +- .../TabularColumnTransformer.py | 14 +-- .../encoding/NoEncoder.py | 2 +- .../encoding/base_encoder.py | 2 +- .../imputation/base_imputer.py | 2 +- .../tabular_preprocessing/scaling/NoScaler.py | 2 +- .../scaling/base_scaler.py | 2 +- .../early_preprocessor/EarlyPreprocessing.py | 17 ++- .../network_backbone/base_network_backbone.py | 9 +- .../base_network_embedding.py | 39 +++--- .../training/data_loader/base_data_loader.py | 16 +-- .../data_loader/feature_data_loader.py | 5 +- .../training/data_loader/image_data_loader.py | 6 +- test/test_data/test_feature_validator.py | 54 ++++---- test/test_datasets/test_tabular_dataset.py | 1 - .../components/preprocessing/test_encoders.py | 2 +- .../components/preprocessing/test_imputers.py | 2 +- .../components/preprocessing/test_scalers.py | 8 +- .../test_tabular_column_transformer.py | 9 +- .../training/test_feature_data_loader.py | 20 +-- .../components/training/test_training.py | 5 - 22 files changed, 155 insertions(+), 180 deletions(-) diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index 6895e8478..c58af736d 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -14,44 +14,33 @@ from sklearn.exceptions import NotFittedError from sklearn.impute import SimpleImputer from sklearn.pipeline import make_pipeline -from sklearn.preprocessing import OneHotEncoder, StandardScaler +from sklearn.preprocessing import OrdinalEncoder from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SUPPORTED_FEAT_TYPES def _create_column_transformer( preprocessors: Dict[str, List[BaseEstimator]], - numerical_columns: List[str], categorical_columns: List[str], ) -> ColumnTransformer: """ Given a dictionary of preprocessors, this function creates a sklearn column transformer with appropriate columns associated with their preprocessors. - Args: preprocessors (Dict[str, List[BaseEstimator]]): Dictionary containing list of numerical and categorical preprocessors. 
- numerical_columns (List[str]): - List of names of numerical columns categorical_columns (List[str]): List of names of categorical columns - Returns: ColumnTransformer """ - numerical_pipeline = 'drop' - categorical_pipeline = 'drop' - if len(numerical_columns) > 0: - numerical_pipeline = make_pipeline(*preprocessors['numerical']) - if len(categorical_columns) > 0: - categorical_pipeline = make_pipeline(*preprocessors['categorical']) + categorical_pipeline = make_pipeline(*preprocessors['categorical']) return ColumnTransformer([ - ('categorical_pipeline', categorical_pipeline, categorical_columns), - ('numerical_pipeline', numerical_pipeline, numerical_columns)], - remainder='drop' + ('categorical_pipeline', categorical_pipeline, categorical_columns)], + remainder='passthrough' ) @@ -59,22 +48,17 @@ def get_tabular_preprocessors() -> Dict[str, List[BaseEstimator]]: """ This function creates a Dictionary containing a list of numerical and categorical preprocessors - Returns: Dict[str, List[BaseEstimator]] """ preprocessors: Dict[str, List[BaseEstimator]] = dict() # Categorical Preprocessors - onehot_encoder = OneHotEncoder(categories='auto', sparse=False, handle_unknown='ignore') + ordinal_encoder = OrdinalEncoder(handle_unknown='use_encoded_value', + unknown_value=-1) categorical_imputer = SimpleImputer(strategy='constant', copy=False) - # Numerical Preprocessors - numerical_imputer = SimpleImputer(strategy='median', copy=False) - standard_scaler = StandardScaler(with_mean=True, with_std=True, copy=False) - - preprocessors['categorical'] = [categorical_imputer, onehot_encoder] - preprocessors['numerical'] = [numerical_imputer, standard_scaler] + preprocessors['categorical'] = [categorical_imputer, ordinal_encoder] return preprocessors @@ -161,31 +145,48 @@ def _fit( X = cast(pd.DataFrame, X) - self.all_nan_columns = set([column for column in X.columns if X[column].isna().all()]) - - categorical_columns, numerical_columns, feat_type = self._get_columns_info(X) - - self.enc_columns = categorical_columns + all_nan_columns = [] + for column in X.columns: + if X[column].isna().all(): + X[column] = pd.to_numeric(X[column]) + # Also note this change in self.dtypes + if len(self.dtypes) != 0: + self.dtypes[list(X.columns).index(column)] = X[column].dtype.name + all_nan_columns.append(column) + self.all_nan_columns = set(all_nan_columns) + + self.enc_columns, self.feat_type = self._get_columns_info(X) + + if len(self.enc_columns) > 0: + + preprocessors = get_tabular_preprocessors() + self.column_transformer = _create_column_transformer( + preprocessors=preprocessors, + categorical_columns=self.enc_columns, + ) - preprocessors = get_tabular_preprocessors() - self.column_transformer = _create_column_transformer( - preprocessors=preprocessors, - numerical_columns=numerical_columns, - categorical_columns=categorical_columns, - ) + # Mypy redefinition + assert self.column_transformer is not None + self.column_transformer.fit(X) - # Mypy redefinition - assert self.column_transformer is not None - self.column_transformer.fit(X) + # The column transformer reorders the feature types + # therefore, we need to change the order of columns as well + # This means categorical columns are shifted to the left - # The column transformer reorders the feature types - # therefore, we need to change the order of columns as well - # This means categorical columns are shifted to the left + self.feat_type = sorted( + self.feat_type, + key=functools.cmp_to_key(self._comparator) + ) - self.feat_type = sorted( - feat_type, - 
key=functools.cmp_to_key(self._comparator)
-        )
+            encoded_categories = self.column_transformer.\
+                named_transformers_['categorical_pipeline'].\
+                named_steps['ordinalencoder'].categories_
+            self.categories = [
+                # We fit an ordinal encoder, where all categorical
+                # columns are shifted to the left
+                list(range(len(cat)))
+                for cat in encoded_categories
+            ]

         # differently to categorical_columns and numerical_columns,
         # this saves the index of the column.
@@ -264,6 +265,17 @@ def transform(
         if hasattr(X, "iloc") and not scipy.sparse.issparse(X):
             X = cast(Type[pd.DataFrame], X)
+            if self.all_nan_columns is not None:
+                for column in X.columns:
+                    if column in self.all_nan_columns:
+                        if not X[column].isna().all():
+                            X[column] = np.nan
+                        X[column] = pd.to_numeric(X[column])
+            if len(self.categorical_columns) > 0:
+                categorical_columns = self.column_transformer.transformers_[0][-1]
+                for column in categorical_columns:
+                    if X[column].isna().all():
+                        X[column] = X[column].astype('object')

         # Check the data here so we catch problems on new test data
         self._check_data(X)
@@ -273,11 +285,6 @@ def transform(
         # We need to convert the column in test data to
         # object otherwise the test column is interpreted as float
         if self.column_transformer is not None:
-            if len(self.categorical_columns) > 0:
-                categorical_columns = self.column_transformer.transformers_[0][-1]
-                for column in categorical_columns:
-                    if X[column].isna().all():
-                        X[column] = X[column].astype('object')
             X = self.column_transformer.transform(X)

         # Sparse related transformations
@@ -362,10 +369,9 @@ def _check_data(

         dtypes = [dtype.name for dtype in X.dtypes]

-        diff_cols = X.columns[[s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]]
         if len(self.dtypes) == 0:
             self.dtypes = dtypes
-        elif not self._is_datasets_consistent(diff_cols, X):
+        elif self.dtypes != dtypes:
             raise ValueError("The dtype of the features must not be changed after fit(), but"
                              " the dtypes of some columns are different between training ({}) and"
                              " test ({}) datasets.".format(self.dtypes, dtypes))
@@ -373,7 +379,7 @@ def _check_data(
     def _get_columns_info(
         self,
         X: pd.DataFrame,
-    ) -> Tuple[List[str], List[str], List[str]]:
+    ) -> Tuple[List[str], List[str]]:
         """
         Return the columns to be encoded from a pandas dataframe
@@ -392,15 +398,12 @@ def _get_columns_info(
         """

         # Register if a column needs encoding
-        numerical_columns = []
         categorical_columns = []
         # Also, register the feature types for the estimator
         feat_type = []

         # Make sure each column is a valid type
         for i, column in enumerate(X.columns):
-            if self.all_nan_columns is not None and column in self.all_nan_columns:
-                continue
             column_dtype = self.dtypes[i]
             err_msg = "Valid types are `numerical`, `categorical` or `boolean`, " \
                       "but input column {} has an invalid type `{}`.".format(column, column_dtype)
@@ -411,7 +414,6 @@ def _get_columns_info(
             # TypeError: data type not understood in certain pandas types
             elif is_numeric_dtype(column_dtype):
                 feat_type.append('numerical')
-                numerical_columns.append(column)
             elif column_dtype == 'object':
                 # TODO verify how would this happen when we always convert the object dtypes to category
                 raise TypeError(
@@ -437,7 +439,7 @@ def _get_columns_info(
                     "before feeding it to AutoPyTorch.".format(err_msg)
                 )

-        return categorical_columns, numerical_columns, feat_type
+        return categorical_columns, feat_type

     def list_to_pandas(
         self,
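Note on tabular_feature_validator.py: fit() now builds a ColumnTransformer over the categorical columns only and lets everything else pass through, and the one-hot encoder is replaced by an ordinal encoder that maps unseen categories to -1. A minimal self-contained sketch of that arrangement (the toy data and column names are illustrative, not from the PR; `handle_unknown='use_encoded_value'` needs scikit-learn >= 0.24):

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OrdinalEncoder

# Toy training data: one categorical and one numerical column.
X_train = pd.DataFrame({'colour': pd.Series(['blue', 'red', 'blue'], dtype='category'),
                        'size': [1.0, 2.0, 3.0]})

# Mirrors get_tabular_preprocessors() / _create_column_transformer():
# only categorical columns get a pipeline, the remainder passes through.
categorical_pipeline = make_pipeline(
    SimpleImputer(strategy='constant', copy=False),
    OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1),
)
preprocessor = ColumnTransformer(
    [('categorical_pipeline', categorical_pipeline, ['colour'])],
    remainder='passthrough',
).fit(X_train)

# Categorical columns come out left-most; a category never seen during fit
# ('green') becomes -1 instead of raising an error.
X_test = pd.DataFrame({'colour': pd.Series(['green', 'red'], dtype='category'),
                       'size': [4.0, 5.0]})
print(preprocessor.transform(X_test))  # [[-1., 4.], [1., 5.]]
```

This matches what the updated tests further down assert: ordinal codes 0, 1, ... for known categories and -1 for unknown ones.

diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py
index 0c48ac06d..f65784840 100644
--- 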
a/autoPyTorch/datasets/base_dataset.py +++ b/autoPyTorch/datasets/base_dataset.py @@ -146,7 +146,7 @@ def __init__( # TODO: Look for a criteria to define small enough to preprocess # False for the regularization cocktails initially - self.is_small_preprocess = False + # self.is_small_preprocess = False # Make sure cross validation splits are created once self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes) diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py index e8f95ab57..05bede68a 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py @@ -3,14 +3,14 @@ import numpy as np from sklearn.compose import ColumnTransformer -# from sklearn.pipeline import make_pipeline +from sklearn.pipeline import make_pipeline import torch from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.base_tabular_preprocessing import ( autoPyTorchTabularPreprocessingComponent ) -# from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.utils import get_tabular_preprocessers +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.utils import get_tabular_preprocessers from autoPyTorch.utils.common import FitRequirement, subsampler @@ -52,11 +52,11 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "TabularColumnTransformer": numerical_pipeline = 'passthrough' categorical_pipeline = 'passthrough' - # preprocessors = get_tabular_preprocessers(X) - # if len(X['dataset_properties']['numerical_columns']): - # numerical_pipeline = make_pipeline(*preprocessors['numerical']) - # if len(X['dataset_properties']['categorical_columns']): - # categorical_pipeline = make_pipeline(*preprocessors['categorical']) + preprocessors = get_tabular_preprocessers(X) + if len(X['dataset_properties']['numerical_columns']): + numerical_pipeline = make_pipeline(*preprocessors['numerical']) + if len(X['dataset_properties']['categorical_columns']): + categorical_pipeline = make_pipeline(*preprocessors['categorical']) self.preprocessor = ColumnTransformer([ ('numerical_pipeline', numerical_pipeline, X['dataset_properties']['numerical_columns']), diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/NoEncoder.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/NoEncoder.py index d62ee26d2..929e99048 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/NoEncoder.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/NoEncoder.py @@ -40,7 +40,7 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: Returns: (Dict[str, Any]): the updated 'X' dictionary """ - # X.update({'encoder': self.preprocessor}) + X.update({'encoder': self.preprocessor}) return X @staticmethod diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder.py index 9829cadcd..eadc0a188 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder.py @@ -28,5 +28,5 @@ def transform(self, X: 
Dict[str, Any]) -> Dict[str, Any]: if self.preprocessor['numerical'] is None and self.preprocessor['categorical'] is None: raise ValueError("cant call transform on {} without fitting first." .format(self.__class__.__name__)) - # X.update({'encoder': self.preprocessor}) + X.update({'encoder': self.preprocessor}) return X diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/imputation/base_imputer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/imputation/base_imputer.py index ac0648481..b65f3c229 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/imputation/base_imputer.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/imputation/base_imputer.py @@ -29,5 +29,5 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: if self.preprocessor['numerical'] is None and self.preprocessor['categorical'] is None: raise ValueError("cant call transform on {} without fitting first." .format(self.__class__.__name__)) - # X.update({'imputer': self.preprocessor}) + X.update({'imputer': self.preprocessor}) return X diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/NoScaler.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/NoScaler.py index 9775d17dd..9d50aa8f5 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/NoScaler.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/NoScaler.py @@ -43,7 +43,7 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: Returns: np.ndarray: Transformed features """ - # X.update({'scaler': self.preprocessor}) + X.update({'scaler': self.preprocessor}) return X @staticmethod diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/base_scaler.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/base_scaler.py index 270fac246..39834dd2b 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/base_scaler.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/base_scaler.py @@ -28,5 +28,5 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: if self.preprocessor['numerical'] is None and self.preprocessor['categorical'] is None: raise ValueError("cant call transform on {} without fitting first." 
.format(self.__class__.__name__))
-        # X.update({'scaler': self.preprocessor})
+        X.update({'scaler': self.preprocessor})
         return X
diff --git a/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py b/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py
index 7fbf33f99..c25ea6bb0 100644
--- a/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py
+++ b/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py
@@ -20,7 +20,7 @@ def __init__(self, random_state: Optional[np.random.RandomState] = None) -> None
         super().__init__()
         self.random_state = random_state
         self.add_fit_requirements([
-            FitRequirement('is_small_preprocess', (bool,), user_defined=True, dataset_property=True),
+            # FitRequirement('is_small_preprocess', (bool,), user_defined=True, dataset_property=True),
             FitRequirement('X_train', (np.ndarray, pd.DataFrame, csr_matrix), user_defined=True,
                            dataset_property=False)])
@@ -32,14 +32,13 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "EarlyPreprocessing":

     def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
         transforms = get_preprocess_transforms(X)
-        if X['dataset_properties']['is_small_preprocess']:
-            if 'X_train' in X:
-                X_train = X['X_train']
-            else:
-                # Incorporate the transform to the dataset
-                X_train = X['backend'].load_datamanager().train_tensors[0]
-
-            X['X_train'] = preprocess(dataset=X_train, transforms=transforms)
+        if 'X_train' in X:
+            X_train = X['X_train']
+        else:
+            # Incorporate the transform to the dataset
+            X_train = X['backend'].load_datamanager().train_tensors[0]
+
+        X['X_train'] = preprocess(dataset=X_train, transforms=transforms)

         # We need to also save the preprocess transforms for inference
         X.update({'preprocess_transforms': transforms})
diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py
index 1a04d6645..50050acaf 100644
--- a/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py
+++ b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py
@@ -28,7 +28,7 @@ def __init__(self, **kwargs: Any):
         super().__init__()
         self.add_fit_requirements([
-            FitRequirement('is_small_preprocess', (bool,), user_defined=True, dataset_property=True),
+            # FitRequirement('is_small_preprocess', (bool,), user_defined=True, dataset_property=True),
             FitRequirement('X_train', (np.ndarray, pd.DataFrame, csr_matrix), user_defined=True,
                            dataset_property=False),
             FitRequirement('input_shape', (Iterable,), user_defined=True, dataset_property=True),
@@ -52,12 +52,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
         self.check_requirements(X, y)
         X_train = X['X_train']

-        if X["dataset_properties"]["is_small_preprocess"]:
-            input_shape = X_train.shape[1:]
-        else:
-            # get input shape by transforming first two elements of the training set
-            column_transformer = X['tabular_transformer'].preprocessor
-            input_shape = column_transformer.transform(X_train[:1]).shape[1:]
+        input_shape = X_train.shape[1:]

         input_shape = get_output_shape(X['network_embedding'], input_shape=input_shape)
         self.input_shape = input_shape
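Note on base_network_backbone.py: since the data is now always preprocessed before the network components run, the backbone reads `input_shape` directly from `X_train` instead of transforming a sample first; `get_output_shape(X['network_embedding'], input_shape=input_shape)` then yields the backbone's effective input size. The embedding restored in the next file sizes itself from per-feature metadata; a small sketch of the `num_input_features` layout its `_get_args` builds (the counts below are made-up stand-ins):

```python
import numpy as np

# Hypothetical feature layout, for illustration only:
num_numerical_columns = 3                      # numerical features after preprocessing
categories = [list(range(4)), list(range(2))]  # categories of each categorical column

# One entry per input feature, as in _get_args(): zero for numerical features,
# the category count for categorical ones, so the embedding can size one
# lookup table per categorical column.
num_input_features = np.zeros(num_numerical_columns + len(categories), dtype=int)
for i, category in enumerate(categories):
    num_input_features[num_numerical_columns + i] = len(category)

print(num_input_features)  # [0 0 0 4 2]
```

diff --git a/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py b/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py
index 844a4616b..53a1b19e9 100644
--- a/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py
+++ 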
b/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py @@ -1,4 +1,4 @@ -# import copy +import copy from typing import Any, Dict, Optional, Tuple import numpy as np @@ -21,7 +21,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: self.embedding = self.build_embedding( num_input_features=num_input_features, - num_numerical_features=num_numerical_columns) # type: ignore[arg-type] + num_numerical_features=num_numerical_columns) return self def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: @@ -31,22 +31,21 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: def build_embedding(self, num_input_features: np.ndarray, num_numerical_features: int) -> nn.Module: raise NotImplementedError - def _get_args(self, X: Dict[str, Any]) -> Tuple[None, None]: # Tuple[int, np.ndarray]: + def _get_args(self, X: Dict[str, Any]) -> Tuple[int, np.ndarray]: # Feature preprocessors can alter numerical columns - # if len(X['dataset_properties']['numerical_columns']) == 0: - # num_numerical_columns = 0 - # else: - # X_train = copy.deepcopy(X['backend'].load_datamanager().train_tensors[0][:2]) - # - # numerical_column_transformer = X['tabular_transformer'].preprocessor. \ - # named_transformers_['numerical_pipeline'] - # num_numerical_columns = numerical_column_transformer.transform( - # X_train[:, X['dataset_properties']['numerical_columns']]).shape[1] - # num_input_features = np.zeros((num_numerical_columns + len(X['dataset_properties']['categorical_columns'])), - # dtype=int) - # categories = X['dataset_properties']['categories'] - # - # for i, category in enumerate(categories): - # num_input_features[num_numerical_columns + i, ] = len(category) - # return num_numerical_columns, num_input_features - return None, None + if len(X['dataset_properties']['numerical_columns']) == 0: + num_numerical_columns = 0 + else: + X_train = copy.deepcopy(X['backend'].load_datamanager().train_tensors[0][:2]) + + numerical_column_transformer = X['tabular_transformer'].preprocessor. 
\
+                named_transformers_['numerical_pipeline']
+            num_numerical_columns = numerical_column_transformer.transform(
+                X_train[:, X['dataset_properties']['numerical_columns']]).shape[1]
+        num_input_features = np.zeros((num_numerical_columns + len(X['dataset_properties']['categorical_columns'])),
+                                      dtype=int)
+        categories = X['dataset_properties']['categories']
+
+        for i, category in enumerate(categories):
+            num_input_features[num_numerical_columns + i, ] = len(category)
+        return num_numerical_columns, num_input_features
diff --git a/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py b/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py
index 058ffe904..a11882816 100644
--- a/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py
+++ b/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py
@@ -57,7 +57,8 @@ def __init__(self, batch_size: int = 64,
         self.add_fit_requirements([
             FitRequirement("split_id", (int,), user_defined=True, dataset_property=False),
             FitRequirement("Backend", (Backend,), user_defined=True, dataset_property=False),
-            FitRequirement("is_small_preprocess", (bool,), user_defined=True, dataset_property=True)])
+            # FitRequirement("is_small_preprocess", (bool,), user_defined=True, dataset_property=True)
+        ])

     def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
         """The transform function calls the transform function of the
@@ -102,10 +103,9 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> torch.utils.data.DataLoader:
                                                  self.val_transform,
                                                  train=False,
                                                  )
-        if X['dataset_properties']["is_small_preprocess"]:
-            # This parameter indicates that the data has been pre-processed for speed
-            # Overwrite the datamanager with the pre-processes data
-            datamanager.replace_data(X['X_train'], X['X_test'] if 'X_test' in X else None)
+        # The data is always pre-processed up front for speed.
+        # Overwrite the datamanager with the pre-processed data.
+        datamanager.replace_data(X['X_train'], X['X_test'] if 'X_test' in X else None)

         train_dataset = datamanager.get_dataset_for_training(split_id=X['split_id'], train=True)

@@ -226,9 +226,9 @@ def check_requirements(self, X: Dict[str, Any], y: Any = None) -> None:
         if 'backend' not in X:
             raise ValueError("backend is needed to load the data from disk")

-        if 'is_small_preprocess' not in X['dataset_properties']:
-            raise ValueError("is_small_pre-process is required to know if the data was preprocessed"
-                             " or if the data-loader should transform it while loading a batch")
+        # if 'is_small_preprocess' not in X['dataset_properties']:
+        #     raise ValueError("is_small_pre-process is required to know if the data was preprocessed"
+        #                      " or if the data-loader should transform it while loading a batch")

         # We expect this class to be a base for image/tabular/time
         # And the difference among this data types should be mainly
diff --git a/autoPyTorch/pipeline/components/training/data_loader/feature_data_loader.py b/autoPyTorch/pipeline/components/training/data_loader/feature_data_loader.py
index 4e41ec838..b76baf8cf 100644
--- a/autoPyTorch/pipeline/components/training/data_loader/feature_data_loader.py
+++ b/autoPyTorch/pipeline/components/training/data_loader/feature_data_loader.py
@@ -72,7 +72,7 @@ def build_transform(self, X: Dict[str, Any], mode: str) -> torchvision.transform
         # distinction is performed
         candidate_transformations: List[Callable] = []

-        if 'test' in mode or not X['dataset_properties']['is_small_preprocess']:
+        if 'test' in mode:
candidate_transformations.append((ExpandTransform())) candidate_transformations.extend(X['preprocess_transforms']) candidate_transformations.append((ContractTransform())) @@ -93,5 +93,4 @@ def _check_transform_requirements(self, X: Dict[str, Any], y: Any = None) -> Non mechanism, in which during a transform, a components adds relevant information so that further stages can be properly fitted """ - if not X['dataset_properties']['is_small_preprocess'] and 'preprocess_transforms' not in X: - raise ValueError("Cannot find the preprocess_transforms in the fit dictionary") + pass diff --git a/autoPyTorch/pipeline/components/training/data_loader/image_data_loader.py b/autoPyTorch/pipeline/components/training/data_loader/image_data_loader.py index 21cc05447..2dcf72ab8 100644 --- a/autoPyTorch/pipeline/components/training/data_loader/image_data_loader.py +++ b/autoPyTorch/pipeline/components/training/data_loader/image_data_loader.py @@ -41,7 +41,7 @@ def build_transform(self, X: Dict[str, Any], mode: str) -> torchvision.transform # check if data set is small enough to be preprocessed. # If it is, then no need to add preprocess_transforms to # the data loader as the data is already preprocessed - if 'test' in mode or not X['dataset_properties']['is_small_preprocess']: + if 'test' in mode: transformations.append(X['preprocess_transforms']) # Transform to tensor @@ -63,5 +63,5 @@ def _check_transform_requirements(self, X: Dict[str, Any], y: Any = None) -> Non if not X['image_augmenter'] and 'image_augmenter' not in X: raise ValueError("Cannot find the image_augmenter in the fit dictionary") - if not X['dataset_properties']['is_small_preprocess'] and 'preprocess_transforms' not in X: - raise ValueError("Cannot find the preprocess_transforms in the fit dictionary") + # if not X['dataset_properties']['is_small_preprocess'] and 'preprocess_transforms' not in X: + # raise ValueError("Cannot find the preprocess_transforms in the fit dictionary") diff --git a/test/test_data/test_feature_validator.py b/test/test_data/test_feature_validator.py index c8e05182c..59212fb28 100644 --- a/test/test_data/test_feature_validator.py +++ b/test/test_data/test_feature_validator.py @@ -237,7 +237,7 @@ def test_featurevalidator_categorical_nan(input_data_featuretest): transformed_X = validator.transform(input_data_featuretest) assert any(pd.isna(input_data_featuretest)) categories_ = validator.column_transformer.\ - named_transformers_['categorical_pipeline'].named_steps['onehotencoder'].categories_ + named_transformers_['categorical_pipeline'].named_steps['ordinalencoder'].categories_ assert any(('0' in categories) or (0 in categories) or ('missing_value' in categories) for categories in categories_) assert np.issubdtype(transformed_X.dtype, np.number) @@ -313,9 +313,8 @@ def test_featurevalidator_get_columns_to_encode(): validator.fit(df) - categorical_columns, numerical_columns, feat_type = validator._get_columns_info(df) + categorical_columns, feat_type = validator._get_columns_info(df) - assert numerical_columns == ['int', 'float'] assert categorical_columns == ['category', 'bool'] assert feat_type == ['numerical', 'numerical', 'categorical', 'categorical'] @@ -327,8 +326,8 @@ def feature_validator_remove_nan_catcolumns(df_train: pd.DataFrame, df_test: pd. 
transformed_df_train = validator.transform(df_train)
     transformed_df_test = validator.transform(df_test)

-    assert np.array_equal(transformed_df_train, ans_train)
-    assert np.array_equal(transformed_df_test, ans_test)
+    np.testing.assert_array_equal(transformed_df_train, ans_train)
+    np.testing.assert_array_equal(transformed_df_test, ans_test)

 def test_feature_validator_remove_nan_catcolumns():
@@ -373,7 +372,7 @@ def test_feature_validator_remove_nan_catcolumns():
         ],
         dtype='category',
     )
-    ans_train = np.array([[0, 1], [1, 0], [0, 1]], dtype=np.float64)
+    ans_train = np.array([[1, np.nan, np.nan], [0, np.nan, np.nan], [1, np.nan, np.nan]], dtype=np.float64)
     df_test = pd.DataFrame(
         [
             {'A': np.nan, 'B': np.nan, 'C': 5},
             {'A': np.nan, 'B': 3, 'C': np.nan},
             {'A': 2, 'B': np.nan, 'C': np.nan},
         ],
         dtype='category',
     )
-    ans_test = np.array([[1, 0], [1, 0], [0, 1]], dtype=np.float64)
+    ans_test = np.array([[0, np.nan, 5], [0, np.nan, np.nan], [1, np.nan, np.nan]], dtype=np.float64)
     feature_validator_remove_nan_catcolumns(df_train, df_test, ans_train, ans_test)

     # Second case, there exist null columns (B and C) in the training set and
@@ -395,7 +394,7 @@ def test_feature_validator_remove_nan_catcolumns():
         ],
         dtype='category',
     )
-    ans_train = np.array([[0, 1], [1, 0], [0, 1]], dtype=np.float64)
+    ans_train = np.array([[1, np.nan, np.nan], [0, np.nan, np.nan], [1, np.nan, np.nan]], dtype=np.float64)
     df_test = pd.DataFrame(
         [
             {'A': np.nan, 'B': np.nan, 'C': np.nan},
             {'A': np.nan, 'B': np.nan, 'C': np.nan},
             {'A': 2, 'B': np.nan, 'C': np.nan},
         ],
         dtype='category',
     )
-    ans_test = np.array([[1, 0], [1, 0], [0, 1]], dtype=np.float64)
+    ans_test = np.array([[0, np.nan, np.nan], [0, np.nan, np.nan], [1, np.nan, np.nan]], dtype=np.float64)
     feature_validator_remove_nan_catcolumns(df_train, df_test, ans_train, ans_test)

     # Third case, there exist no null columns in the training set and
@@ -416,7 +415,7 @@ def test_feature_validator_remove_nan_catcolumns():
         ],
         dtype='category',
     )
-    ans_train = np.array([[1, 0, 1, 0], [0, 1, 0, 1]], dtype=np.float64)
+    ans_train = np.array([[0, 0], [1, 1]], dtype=np.float64)
     df_test = pd.DataFrame(
         [
             {'A': np.nan, 'B': np.nan},
         ],
         dtype='category',
     )
-    ans_test = np.array([[0, 0, 0, 0], [0, 0, 0, 0]], dtype=np.float64)
+    ans_test = np.array([[-1, -1], [-1, -1]], dtype=np.float64)
     feature_validator_remove_nan_catcolumns(df_train, df_test, ans_train, ans_test)

@@ -504,22 +503,30 @@ def test_column_transformer_created(input_data_featuretest):

     # Make sure that the encoded features are actually encoded. Categorical columns are at
     # the start after transformation.
In our fixtures, this is also honored prior encode
-    cat_columns, _, feature_types = validator._get_columns_info(input_data_featuretest)
+    cat_columns, feature_types = validator._get_columns_info(input_data_featuretest)

     # At least one categorical
     assert 'categorical' in validator.feat_type

     # Numerical if the original data has numerical only columns
     if np.any([pd.api.types.is_numeric_dtype(input_data_featuretest[col]
                                              ) for col in input_data_featuretest.columns]):
         assert 'numerical' in validator.feat_type
-        # we expect this input to be the fixture 'pandas_mixed_nan'
-        np.testing.assert_array_equal(transformed_X, np.array([[1., 0., -1.], [0., 1., 1.]]))
-    else:
-        np.testing.assert_array_equal(transformed_X, np.array([[1., 0., 1., 0.], [0., 1., 0., 1.]]))
-
-    if not all([feat_type in ['numerical', 'categorical'] for feat_type in feature_types]):
-        raise ValueError("Expected only numerical and categorical feature types")
+    for i, feat_type in enumerate(feature_types):
+        if 'numerical' in feat_type:
+            np.testing.assert_array_equal(
+                transformed_X[:, i],
+                input_data_featuretest[input_data_featuretest.columns[i]].to_numpy()
+            )
+        elif 'categorical' in feat_type:
+            np.testing.assert_array_equal(
+                transformed_X[:, i],
+                # Expect always 0, 1... because we use an ordinal encoder
+                np.array([0, 1])
+            )
+        else:
+            raise ValueError(feat_type)

 def test_no_new_category_after_fit():
@@ -554,7 +561,7 @@ def test_unknown_encode_value():
     # The first row should have a 0, 0 as we added a
     # new categorical there and one hot encoder marks
     # it as all zeros for the transformed column
-    expected_row = [0.0, 0.0, -0.5584294383572701, 0.5000000000000004, -1.5136598016833485]
+    expected_row = [-1, -41, -3, -987.2]
     assert expected_row == x_t[0].tolist()

@@ -678,16 +685,11 @@ def test_feature_validator_imbalanced_data():
     validator.fit(X_train)

     train_feature_types = copy.deepcopy(validator.feat_type)
-    assert train_feature_types == ['numerical']
+    assert train_feature_types == ['numerical', 'numerical', 'numerical', 'numerical']

     # validator will throw an error if the column types are not the same
     transformed_X_test = validator.transform(X_test)
     transformed_X_test = pd.DataFrame(transformed_X_test)
     assert sorted(validator.all_nan_columns) == sorted(['A', 'C', 'D'])
-    # as there are no categorical columns, we can make such an
-    # assertion. We only expect to drop the all nan columns
-    total_all_nan_columns = len(validator.all_nan_columns)
-    total_columns = len(validator.column_order)
-    assert total_columns - total_all_nan_columns == len(transformed_X_test.columns)

     # Columns with not all null values in the train split and
     # completely null on the test split.
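Note on the validator tests above: they pin down the two transform()-side rules introduced in tabular_feature_validator.py — a column that was entirely NaN at fit time is kept as an all-NaN float column instead of being dropped, and categories unseen at fit time encode to -1. A compact sketch of the all-NaN rule, mirroring the transform() hunk (column names are toy stand-ins):

```python
import numpy as np
import pandas as pd

# Columns that contained only NaN when the validator was fitted.
all_nan_columns = {'B'}

X = pd.DataFrame({'A': [1, 2], 'B': ['a', np.nan]}, dtype='category')

for column in X.columns:
    if column in all_nan_columns:
        if not X[column].isna().all():
            # Values appearing only at transform time carry no information
            # the model was fitted with, so they are blanked out ...
            X[column] = np.nan
        # ... and the column is kept as an all-NaN float column rather than
        # dropped, so train and test keep the same number of columns.
        X[column] = pd.to_numeric(X[column])

print(X['B'].dtype)  # float64
```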
diff --git a/test/test_datasets/test_tabular_dataset.py b/test/test_datasets/test_tabular_dataset.py index 409e6bdec..d49e5dddd 100644 --- a/test/test_datasets/test_tabular_dataset.py +++ b/test/test_datasets/test_tabular_dataset.py @@ -25,7 +25,6 @@ def test_get_dataset_properties(backend, fit_dictionary_tabular): 'categorical_columns', 'numerical_columns', 'issparse', - 'is_small_preprocess', 'task_type', 'output_type', 'input_shape', diff --git a/test/test_pipeline/components/preprocessing/test_encoders.py b/test/test_pipeline/components/preprocessing/test_encoders.py index ac796291c..a6263d1a6 100644 --- a/test/test_pipeline/components/preprocessing/test_encoders.py +++ b/test/test_pipeline/components/preprocessing/test_encoders.py @@ -11,7 +11,7 @@ # TODO: fix in preprocessing PR -@unittest.skip("Skipping tests as preprocessing is not finalised") +# @unittest.skip("Skipping tests as preprocessing is not finalised") class TestEncoders(unittest.TestCase): def test_one_hot_encoder_no_unknown(self): diff --git a/test/test_pipeline/components/preprocessing/test_imputers.py b/test/test_pipeline/components/preprocessing/test_imputers.py index d2de6d7d3..e5f191d1d 100644 --- a/test/test_pipeline/components/preprocessing/test_imputers.py +++ b/test/test_pipeline/components/preprocessing/test_imputers.py @@ -12,7 +12,7 @@ # TODO: fix in preprocessing PR -@unittest.skip("Skipping tests as preprocessing is not finalised") +# @unittest.skip("Skipping tests as preprocessing is not finalised") class TestSimpleImputer(unittest.TestCase): def test_get_config_space(self): diff --git a/test/test_pipeline/components/preprocessing/test_scalers.py b/test/test_pipeline/components/preprocessing/test_scalers.py index cd41308fa..b98d9e055 100644 --- a/test/test_pipeline/components/preprocessing/test_scalers.py +++ b/test/test_pipeline/components/preprocessing/test_scalers.py @@ -13,7 +13,7 @@ # TODO: fix in preprocessing PR -@unittest.skip("Skipping tests as preprocessing is not finalised") +# @unittest.skip("Skipping tests as preprocessing is not finalised") class TestNormalizer(unittest.TestCase): def test_l2_norm(self): @@ -132,7 +132,7 @@ def test_max_norm(self): # TODO: fix in preprocessing PR -@unittest.skip("Skipping tests as preprocessing is not finalised") +# @unittest.skip("Skipping tests as preprocessing is not finalised") class TestMinMaxScaler(unittest.TestCase): def test_minmax_scaler(self): @@ -175,7 +175,7 @@ def test_minmax_scaler(self): # TODO: fix in preprocessing PR -@unittest.skip("Skipping tests as preprocessing is not finalised") +# @unittest.skip("Skipping tests as preprocessing is not finalised") class TestStandardScaler(unittest.TestCase): def test_standard_scaler(self): @@ -219,7 +219,7 @@ def test_standard_scaler(self): # TODO: fix in preprocessing PR -@unittest.skip("Skipping tests as preprocessing is not finalised") +# @unittest.skip("Skipping tests as preprocessing is not finalised") class TestNoneScaler(unittest.TestCase): def test_none_scaler(self): diff --git a/test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py b/test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py index d7a59383c..c4d8ccd50 100644 --- a/test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py +++ b/test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py @@ -14,13 +14,14 @@ # TODO: fix in preprocessing PR -@pytest.mark.skip("Skipping tests as preprocessing is not finalised") +# @pytest.mark.skip("Skipping tests 
as preprocessing is not finalised") @pytest.mark.parametrize("fit_dictionary_tabular", ['classification_numerical_only', 'classification_categorical_only', 'classification_numerical_and_categorical'], indirect=True) class TestTabularTransformer: def test_tabular_preprocess(self, fit_dictionary_tabular): pipeline = TabularPipeline(dataset_properties=fit_dictionary_tabular['dataset_properties']) + X_train = fit_dictionary_tabular['X_train'].copy() pipeline = pipeline.fit(fit_dictionary_tabular) X = pipeline.transform(fit_dictionary_tabular) column_transformer = X['tabular_transformer'] @@ -32,17 +33,17 @@ def test_tabular_preprocess(self, fit_dictionary_tabular): # as the later is not callable and runs into error in the compose transform assert isinstance(column_transformer, TabularColumnTransformer) - data = column_transformer.preprocessor.fit_transform(X['X_train']) + data = column_transformer.preprocessor.fit_transform(X_train) assert isinstance(data, np.ndarray) # Make sure no columns are unintentionally dropped after preprocessing if len(fit_dictionary_tabular['dataset_properties']["numerical_columns"]) == 0: categorical_pipeline = column_transformer.preprocessor.named_transformers_['categorical_pipeline'] - categorical_data = categorical_pipeline.transform(X['X_train']) + categorical_data = categorical_pipeline.transform(X_train) assert data.shape[1] == categorical_data.shape[1] elif len(fit_dictionary_tabular['dataset_properties']["categorical_columns"]) == 0: numerical_pipeline = column_transformer.preprocessor.named_transformers_['numerical_pipeline'] - numerical_data = numerical_pipeline.transform(X['X_train']) + numerical_data = numerical_pipeline.transform(X_train) assert data.shape[1] == numerical_data.shape[1] def test_sparse_data(self, fit_dictionary_tabular): diff --git a/test/test_pipeline/components/training/test_feature_data_loader.py b/test/test_pipeline/components/training/test_feature_data_loader.py index 7d4c9d80d..7e97494a4 100644 --- a/test/test_pipeline/components/training/test_feature_data_loader.py +++ b/test/test_pipeline/components/training/test_feature_data_loader.py @@ -9,13 +9,13 @@ class TestFeatureDataLoader(unittest.TestCase): - def test_build_transform_small_preprocess_true(self): + def test_build_transform(self): """ Makes sure a proper composition is created """ loader = FeatureDataLoader() - fit_dictionary = {'dataset_properties': {'is_small_preprocess': True}} + fit_dictionary = {'dataset_properties': {}} for thing in ['imputer', 'scaler', 'encoder']: fit_dictionary[thing] = [unittest.mock.Mock()] @@ -25,19 +25,3 @@ def test_build_transform_small_preprocess_true(self): # No preprocessing needed here as it was done before self.assertEqual(len(compose.transforms), 1) - - def test_build_transform_small_preprocess_false(self): - """ - Makes sure a proper composition is created - """ - loader = FeatureDataLoader() - - fit_dictionary = {'dataset_properties': {'is_small_preprocess': False}, - 'preprocess_transforms': [unittest.mock.Mock()]} - - compose = loader.build_transform(fit_dictionary, mode='train') - - self.assertIsInstance(compose, torchvision.transforms.Compose) - - # We expect the to tensor, the preproces transforms and the check_array - self.assertEqual(len(compose.transforms), 4) diff --git a/test/test_pipeline/components/training/test_training.py b/test/test_pipeline/components/training/test_training.py index a52a12148..7b1cf7730 100644 --- a/test/test_pipeline/components/training/test_training.py +++ 
b/test/test_pipeline/components/training/test_training.py @@ -93,11 +93,6 @@ def test_check_requirements(self): 'backend is needed to load the data from'): loader.fit(fit_dictionary) - # Then the is small fit - fit_dictionary.update({'backend': unittest.mock.Mock()}) - with self.assertRaisesRegex(ValueError, - 'is_small_pre-process is required to know if th'): - loader.fit(fit_dictionary) def test_fit_transform(self): """ Makes sure that fit and transform work as intended """
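Closing note: with `is_small_preprocess` removed, tabular data is always preprocessed up front, and the imputer/scaler/encoder components publish their preprocessors into the fit dictionary again so that TabularColumnTransformer can assemble them. A rough sketch of that hand-off; plain scikit-learn estimators stand in for the fitted autoPyTorch components, and `collect()` is a simplified stand-in for `get_tabular_preprocessers(X)`:

```python
from typing import Any, Dict, List

import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# A fit dictionary as it might look after the imputer, scaler and encoder
# components ran their transform(); each stores a numerical/categorical pair.
X: Dict[str, Any] = {
    'dataset_properties': {'numerical_columns': [0, 1], 'categorical_columns': [2]},
    'imputer': {'numerical': SimpleImputer(strategy='median'),
                'categorical': SimpleImputer(strategy='constant', fill_value='missing')},
    'scaler': {'numerical': StandardScaler(), 'categorical': None},
    'encoder': {'numerical': None, 'categorical': OneHotEncoder(handle_unknown='ignore')},
}

def collect(X: Dict[str, Any], kind: str) -> List[Any]:
    # Gather the preprocessor each component stored for one column kind,
    # skipping components that do not apply to that kind.
    return [X[key][kind] for key in ('imputer', 'scaler', 'encoder') if X[key][kind] is not None]

# This is the assembly TabularColumnTransformer.fit() performs again.
preprocessor = ColumnTransformer([
    ('numerical_pipeline', make_pipeline(*collect(X, 'numerical')),
     X['dataset_properties']['numerical_columns']),
    ('categorical_pipeline', make_pipeline(*collect(X, 'categorical')),
     X['dataset_properties']['categorical_columns']),
])

X_fit = np.array([[0.0, 1.0, 'a'], [2.0, np.nan, 'b']], dtype=object)
print(preprocessor.fit_transform(X_fit).shape)  # (2, 4)
```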