From 6c2e54d21ebea1c746b03fad600b1a2d655d8be5 Mon Sep 17 00:00:00 2001 From: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com> Date: Wed, 9 Feb 2022 20:12:12 +0100 Subject: [PATCH] [feat] Add coalescer (#376) * [fix] Add check dataset in transform as well for test dataset, which does not require fit * [test] Migrate tests from the francisco's PR without modifications * [fix] Modify so that tests pass * [test] Increase the coverage --- autoPyTorch/configs/greedy_portfolio.json | 16 ++ .../coalescer/MinorityCoalescer.py | 44 +++ .../coalescer/NoCoalescer.py | 37 +++ .../coalescer/__init__.py | 254 ++++++++++++++++++ .../coalescer/base_coalescer.py | 33 +++ .../pipeline/tabular_classification.py | 4 + autoPyTorch/pipeline/tabular_regression.py | 4 + autoPyTorch/utils/implementations.py | 127 ++++++++- test/test_api/.tmp_api/runhistory.json | 9 + .../components/preprocessing/base.py | 2 + .../preprocessing/test_coalescer.py | 86 ++++++ test/test_utils/runhistory.json | 14 + test/test_utils/test_coalescer_transformer.py | 101 +++++++ 13 files changed, 730 insertions(+), 1 deletion(-) create mode 100644 autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/MinorityCoalescer.py create mode 100644 autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/NoCoalescer.py create mode 100644 autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/__init__.py create mode 100644 autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/base_coalescer.py create mode 100644 test/test_pipeline/components/preprocessing/test_coalescer.py create mode 100644 test/test_utils/test_coalescer_transformer.py diff --git a/autoPyTorch/configs/greedy_portfolio.json b/autoPyTorch/configs/greedy_portfolio.json index ffc5d98f5..bdcb45401 100644 --- a/autoPyTorch/configs/greedy_portfolio.json +++ b/autoPyTorch/configs/greedy_portfolio.json @@ -1,5 +1,6 @@ [{"data_loader:batch_size": 60, 
"encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "NoFeaturePreprocessor", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -30,6 +31,7 @@ "network_backbone:ShapedMLPBackbone:max_dropout": 0.023271935735825866}, {"data_loader:batch_size": 255, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "NoFeaturePreprocessor", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -63,6 +65,7 @@ "network_backbone:ShapedResNetBackbone:max_dropout": 0.7662454727603789}, {"data_loader:batch_size": 165, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "NoFeaturePreprocessor", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -93,6 +96,7 @@ "network_head:fully_connected:units_layer_1": 128}, {"data_loader:batch_size": 299, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "NoFeaturePreprocessor", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -124,6 +128,7 @@ "network_head:fully_connected:units_layer_1": 128}, {"data_loader:batch_size": 183, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "NoFeaturePreprocessor", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -157,6 +162,7 @@ "network_backbone:ShapedResNetBackbone:max_dropout": 0.27204101593048097}, {"data_loader:batch_size": 21, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "NoFeaturePreprocessor", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -185,6 +191,7 @@ 
"network_head:fully_connected:units_layer_1": 128}, {"data_loader:batch_size": 159, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "TruncatedSVD", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -214,6 +221,7 @@ "network_head:fully_connected:units_layer_1": 128}, {"data_loader:batch_size": 442, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "TruncatedSVD", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -246,6 +254,7 @@ "network_head:fully_connected:units_layer_1": 128}, {"data_loader:batch_size": 140, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "TruncatedSVD", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -278,6 +287,7 @@ "network_head:fully_connected:units_layer_1": 128}, {"data_loader:batch_size": 48, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "NoFeaturePreprocessor", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -305,6 +315,7 @@ "network_head:fully_connected:units_layer_1": 128}, {"data_loader:batch_size": 168, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "NoFeaturePreprocessor", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -337,6 +348,7 @@ "network_backbone:ShapedResNetBackbone:max_dropout": 0.8992826006547855}, {"data_loader:batch_size": 21, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "NoFeaturePreprocessor", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -365,6 +377,7 @@ 
"network_head:fully_connected:units_layer_1": 128}, {"data_loader:batch_size": 163, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "NoFeaturePreprocessor", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -397,6 +410,7 @@ "network_backbone:ShapedResNetBackbone:max_dropout": 0.6341848343636569}, {"data_loader:batch_size": 150, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "NoFeaturePreprocessor", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -430,6 +444,7 @@ "network_backbone:ShapedResNetBackbone:max_dropout": 0.7133813761319248}, {"data_loader:batch_size": 151, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "TruncatedSVD", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -459,6 +474,7 @@ "network_head:fully_connected:units_layer_1": 128}, {"data_loader:batch_size": 42, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "TruncatedSVD", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/MinorityCoalescer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/MinorityCoalescer.py new file mode 100644 index 000000000..69edfcbb6 --- /dev/null +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/MinorityCoalescer.py @@ -0,0 +1,44 @@ +from typing import Any, Dict, Optional, Union + +from ConfigSpace.configuration_space import ConfigurationSpace +from ConfigSpace.hyperparameters import UniformFloatHyperparameter + +import numpy as np + +from 
autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.coalescer.base_coalescer import BaseCoalescer +from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter +from autoPyTorch.utils.implementations import MinorityCoalesceTransformer + + +class MinorityCoalescer(BaseCoalescer): + """Group together categories whose occurrence is less than a specified min_frac """ + def __init__(self, min_frac: float, random_state: np.random.RandomState): + super().__init__() + self.min_frac = min_frac + self.random_state = random_state + + def fit(self, X: Dict[str, Any], y: Any = None) -> BaseCoalescer: + self.check_requirements(X, y) + self.preprocessor['categorical'] = MinorityCoalesceTransformer(min_frac=self.min_frac) + return self + + @staticmethod + def get_hyperparameter_search_space( + dataset_properties: Optional[Dict[str, Any]] = None, + min_frac: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter='min_frac', + value_range=(1e-4, 0.5), + default_value=1e-2, + ), + ) -> ConfigurationSpace: + + cs = ConfigurationSpace() + add_hyperparameter(cs, min_frac, UniformFloatHyperparameter) + return cs + + @staticmethod + def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]: + return { + 'shortname': 'MinorityCoalescer', + 'name': 'MinorityCoalescer', + 'handles_sparse': False + } diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/NoCoalescer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/NoCoalescer.py new file mode 100644 index 000000000..fdc13dec6 --- /dev/null +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/NoCoalescer.py @@ -0,0 +1,37 @@ +from typing import Any, Dict, Optional, Union + +import numpy as np + +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.coalescer.base_coalescer import BaseCoalescer + + +class NoCoalescer(BaseCoalescer):
+ def __init__(self, random_state: np.random.RandomState): + super().__init__() + self.random_state = random_state + self._processing = False + + def fit(self, X: Dict[str, Any], y: Optional[Any] = None) -> BaseCoalescer: + """ + As no coalescing happens, only check the requirements. + + Args: + X (Dict[str, Any]): + fit dictionary + y (Optional[Any]): + Parameter to comply with scikit-learn API. Not used. + + Returns: + instance of self + """ + self.check_requirements(X, y) + + return self + + @staticmethod + def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]: + return { + 'shortname': 'NoCoalescer', + 'name': 'NoCoalescer', + 'handles_sparse': True + } diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/__init__.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/__init__.py new file mode 100644 index 000000000..1139106ce --- /dev/null +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/__init__.py @@ -0,0 +1,254 @@ +import os +from collections import OrderedDict +from typing import Dict, List, Optional, Sequence + +import ConfigSpace.hyperparameters as CSH +from ConfigSpace.configuration_space import ConfigurationSpace + +from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType +from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice +from autoPyTorch.pipeline.components.base_component import ( + ThirdPartyComponents, + autoPyTorchComponent, + find_components, +) +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.coalescer.base_coalescer import BaseCoalescer +from autoPyTorch.utils.common import HyperparameterSearchSpace, HyperparameterValueType + + +coalescer_directory = os.path.split(__file__)[0] +_coalescer = find_components(__package__, + coalescer_directory, + BaseCoalescer) +_addons = ThirdPartyComponents(BaseCoalescer) + + +def 
add_coalescer(coalescer: BaseCoalescer) -> None: + _addons.add_component(coalescer) + + +class CoalescerChoice(autoPyTorchChoice): + """ + Allows for dynamically choosing coalescer component at runtime + """ + proc_name = "coalescer" + + def get_components(self) -> Dict[str, autoPyTorchComponent]: + """Returns the available coalescer components + + Args: + None + + Returns: + Dict[str, autoPyTorchComponent]: all BaseCoalescer components available + as choices for coalescing the categorical columns + """ + # TODO: Create `@property def components(): ...`. + components = OrderedDict() + components.update(_coalescer) + components.update(_addons.components) + return components + + @staticmethod + def _get_default_choice( + avail_components: Dict[str, autoPyTorchComponent], + include: List[str], + exclude: List[str], + defaults: List[str] = ['NoCoalescer', 'MinorityCoalescer'], + ) -> str: + # TODO: Make it a base method + for choice in defaults: + if choice in avail_components and choice in include and choice not in exclude: + return choice + else: + raise RuntimeError( + f"Available components is either not included in `include` {include} or " + f"included in `exclude` {exclude}" + ) + + def _update_config_space( + self, + component: CSH.Hyperparameter, + avail_components: Dict[str, autoPyTorchComponent], + dataset_properties: Dict[str, BaseDatasetPropertiesType] + ) -> None: + # TODO: Make it a base method + cs = ConfigurationSpace() + cs.add_hyperparameter(component) + + # add only child hyperparameters of coalescer choices + for name in component.choices: + updates = self._get_search_space_updates(prefix=name) + func4cs = avail_components[name].get_hyperparameter_search_space + + # search space provides different args, so ignore it + component_config_space = func4cs(dataset_properties, **updates) # type:ignore[call-arg] + parent_hyperparameter = {'parent': component, 'value': name} + cs.add_configuration_space( + name, + component_config_space,
parent_hyperparameter=parent_hyperparameter + ) + + self.configuration_space = cs + + def _check_choices_in_update( + self, + choices_in_update: Sequence[HyperparameterValueType], + avail_components: Dict[str, autoPyTorchComponent] + ) -> None: + # TODO: Make it a base method + if not set(choices_in_update).issubset(avail_components): + raise ValueError( + f"The update for {self.__class__.__name__} is expected to be " + f"a subset of {avail_components}, but got {choices_in_update}" + ) + + def get_hyperparameter_search_space(self, + dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None, + default: Optional[str] = None, + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None) -> ConfigurationSpace: + # TODO: Make it a base method + + if dataset_properties is None: + dataset_properties = dict() + + dataset_properties = {**self.dataset_properties, **dataset_properties} + + avail_cmps = self.get_available_components( + dataset_properties=dataset_properties, + include=include, + exclude=exclude + ) + + if len(avail_cmps) == 0: + raise ValueError(f"No {self.proc_name} found, please add {self.proc_name} to `include` argument") + + include = include if include is not None else list(avail_cmps.keys()) + exclude = exclude if exclude is not None else [] + if default is None: + default = self._get_default_choice(avail_cmps, include, exclude) + + updates = self._get_search_space_updates() + if "__choice__" in updates: + component = self._get_component_with_updates( + updates=updates, + avail_components=avail_cmps, + dataset_properties=dataset_properties + ) + else: + component = self._get_component_without_updates( + default=default, + include=include, + avail_components=avail_cmps, + dataset_properties=dataset_properties + ) + + self.dataset_properties = dataset_properties + self._update_config_space( + component=component, + avail_components=avail_cmps, + dataset_properties=dataset_properties + ) + return self.configuration_space + + 
def _check_dataset_properties(self, dataset_properties: Dict[str, BaseDatasetPropertiesType]) -> None: + """ + A mechanism in code to ensure the correctness of the dataset_properties + It recursively makes sure that the children and parent level requirements + are honored. + + Args: + dataset_properties: + """ + # TODO: Make it a base method + super()._check_dataset_properties(dataset_properties) + if any(key not in dataset_properties for key in ['categorical_columns', 'numerical_columns']): + raise ValueError("Dataset properties must contain information about the type of columns") + + def _get_component_with_updates( + self, + updates: Dict[str, HyperparameterSearchSpace], + avail_components: Dict[str, autoPyTorchComponent], + dataset_properties: Dict[str, BaseDatasetPropertiesType], + ) -> CSH.Hyperparameter: + # TODO: Make it a base method + choice_key = '__choice__' + choices_in_update = updates[choice_key].value_range + default_in_update = updates[choice_key].default_value + self._check_choices_in_update( + choices_in_update=choices_in_update, + avail_components=avail_components + ) + self._check_update_compatiblity(choices_in_update, dataset_properties) + return CSH.CategoricalHyperparameter(choice_key, choices_in_update, default_in_update) + + def _get_component_without_updates( + self, + avail_components: Dict[str, autoPyTorchComponent], + dataset_properties: Dict[str, BaseDatasetPropertiesType], + default: str, + include: List[str] + ) -> CSH.Hyperparameter: + """ + A method to get a hyperparameter information for the component. + This method is run when we do not get updates from _get_search_space_updates. + + Args: + avail_components (Dict[str, autoPyTorchComponent]): + Available components for this processing. + dataset_properties (Dict[str, BaseDatasetPropertiesType]): + The properties of the dataset. + default (str): + The default component for this processing. + include (List[str]): + The components to include for the auto-pytorch searching. 
+ + Returns: + (CSH.Hyperparameter): + The hyperparameter information for this processing. + """ + # TODO: Make an abstract method with NotImplementedError + choice_key = '__choice__' + no_proc_key = 'NoCoalescer' + choices = list(avail_components.keys()) + + assert isinstance(dataset_properties['categorical_columns'], list) # mypy check + if len(dataset_properties['categorical_columns']) == 0: + # only no coalescer is compatible if the dataset has only numericals + default, choices = no_proc_key, [no_proc_key] + if no_proc_key not in include: + raise ValueError("Only no coalescer is compatible for a dataset with no categorical column") + + return CSH.CategoricalHyperparameter(choice_key, choices, default_value=default) + + def _check_update_compatiblity( + self, + choices_in_update: Sequence[HyperparameterValueType], + dataset_properties: Dict[str, BaseDatasetPropertiesType] + ) -> None: + """ + Check the compatibility of the updates for the components + in this processing given dataset properties. + For example, some processing is not compatible with datasets + with no numerical columns. + We would like to check such compatibility in this method. + + Args: + choices_in_update (Sequence[HyperparameterValueType]): + The choices of components in updates + dataset_properties (Dict[str, BaseDatasetPropertiesType]): + The properties of the dataset. 
+ """ + # TODO: Make an abstract method with NotImplementedError + assert isinstance(dataset_properties['categorical_columns'], list) # mypy check + if len(dataset_properties['categorical_columns']) > 0: + # no restriction for update if dataset has categorical columns + return + + if 'NoCoalescer' not in choices_in_update or len(choices_in_update) != 1: + raise ValueError( + "Only no coalescer is compatible for a dataset with no categorical column, " + f"but got {choices_in_update}" + ) diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/base_coalescer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/base_coalescer.py new file mode 100644 index 000000000..b572f8343 --- /dev/null +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/base_coalescer.py @@ -0,0 +1,33 @@ +from typing import Any, Dict, List + +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.base_tabular_preprocessing import ( + autoPyTorchTabularPreprocessingComponent +) +from autoPyTorch.utils.common import FitRequirement + + +class BaseCoalescer(autoPyTorchTabularPreprocessingComponent): + def __init__(self) -> None: + super().__init__() + self._processing = True + self.add_fit_requirements([ + FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True), + FitRequirement('categories', (List,), user_defined=True, dataset_property=True) + ]) + + def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: + """ + Add the preprocessor to the provided fit dictionary `X`. + + Args: + X (Dict[str, Any]): fit dictionary in sklearn + + Returns: + X (Dict[str, Any]): the updated fit dictionary + """ + if self._processing and self.preprocessor['categorical'] is None: + # If we apply minority coalescer, we must have categorical preprocessor! 
+ raise RuntimeError(f"fit() must be called before transform() on {self.__class__.__name__}") + + X.update({'coalescer': self.preprocessor}) + return X diff --git a/autoPyTorch/pipeline/tabular_classification.py b/autoPyTorch/pipeline/tabular_classification.py index 92dc764bb..720d0af64 100644 --- a/autoPyTorch/pipeline/tabular_classification.py +++ b/autoPyTorch/pipeline/tabular_classification.py @@ -19,6 +19,9 @@ from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.TabularColumnTransformer import ( TabularColumnTransformer ) +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.coalescer import ( + CoalescerChoice +) from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding import ( EncoderChoice ) @@ -310,6 +313,7 @@ def _get_pipeline_steps( steps.extend([ ("imputer", SimpleImputer(random_state=self.random_state)), ("variance_threshold", VarianceThreshold(random_state=self.random_state)), + ("coalescer", CoalescerChoice(default_dataset_properties, random_state=self.random_state)), ("encoder", EncoderChoice(default_dataset_properties, random_state=self.random_state)), ("scaler", ScalerChoice(default_dataset_properties, random_state=self.random_state)), ("feature_preprocessor", FeatureProprocessorChoice(default_dataset_properties, diff --git a/autoPyTorch/pipeline/tabular_regression.py b/autoPyTorch/pipeline/tabular_regression.py index daee7f74a..06da9cabb 100644 --- a/autoPyTorch/pipeline/tabular_regression.py +++ b/autoPyTorch/pipeline/tabular_regression.py @@ -19,6 +19,9 @@ from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.TabularColumnTransformer import ( TabularColumnTransformer ) +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.coalescer import ( + CoalescerChoice +) from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding import ( EncoderChoice ) @@ -260,6 +263,7 @@ def _get_pipeline_steps( steps.extend([ ("imputer", 
SimpleImputer(random_state=self.random_state)), ("variance_threshold", VarianceThreshold(random_state=self.random_state)), + ("coalescer", CoalescerChoice(default_dataset_properties, random_state=self.random_state)), ("encoder", EncoderChoice(default_dataset_properties, random_state=self.random_state)), ("scaler", ScalerChoice(default_dataset_properties, random_state=self.random_state)), ("feature_preprocessor", FeatureProprocessorChoice(default_dataset_properties, diff --git a/autoPyTorch/utils/implementations.py b/autoPyTorch/utils/implementations.py index a0b020622..4b699e3c3 100644 --- a/autoPyTorch/utils/implementations.py +++ b/autoPyTorch/utils/implementations.py @@ -1,7 +1,11 @@ -from typing import Any, Callable, Dict, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, Union import numpy as np +from scipy import sparse + +from sklearn.base import BaseEstimator, TransformerMixin + import torch @@ -59,3 +63,124 @@ def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray: @staticmethod def get_properties() -> Dict[str, Any]: return {'supported_losses': ['BCEWithLogitsLoss']} + + +class MinorityCoalesceTransformer(BaseEstimator, TransformerMixin): + """ Group together categories whose occurrence is less than a specified min_frac.""" + def __init__(self, min_frac: Optional[float] = None): + self.min_frac = min_frac + self._categories_to_coalesce: Optional[List[np.ndarray]] = None + + if self.min_frac is not None and (self.min_frac < 0 or self.min_frac > 1): + raise ValueError(f"min_frac for {self.__class__.__name__} must be in [0, 1], but got {min_frac}") + + def _check_dataset(self, X: Union[np.ndarray, sparse.csr_matrix]) -> None: + """ + When transforming datasets, we modify values to: + * 0 for nan values + * -1 for unknown values + * -2 for values to be coalesced + For this reason, we need to check whether datasets have values + smaller than -2 to avoid mis-transformation. 
+ Note that zero-imputation is the default setting in SimpleImputer of sklearn. + + Args: + X (np.ndarray): + The input features from the user, likely transformed by an encoder and imputer. + """ + X_data = X.data if sparse.issparse(X) else X + if np.nanmin(X_data) <= -2: + raise ValueError("The categoricals in input features for MinorityCoalesceTransformer " + "cannot have integers smaller than -2.") + + @staticmethod + def _get_column_data( + X: Union[np.ndarray, sparse.csr_matrix], + col_idx: int, + is_sparse: bool + ) -> Union[np.ndarray, sparse.csr_matrix]: + """ + Args: + X (Union[np.ndarray, sparse.csr_matrix]): + The feature tensor with only categoricals. + col_idx (int): + The index of the column to get the data. + is_sparse (bool): + Whether the tensor is sparse or not. + + Return: + col_data (Union[np.ndarray, sparse.csr_matrix]): + The column data of the tensor. + """ + + if is_sparse: + assert not isinstance(X, np.ndarray) # mypy check + indptr_start = X.indptr[col_idx] + indptr_end = X.indptr[col_idx + 1] + col_data = X.data[indptr_start:indptr_end] + else: + col_data = X[:, col_idx] + + return col_data + + def fit(self, X: Union[np.ndarray, sparse.csr_matrix], + y: Optional[np.ndarray] = None) -> 'MinorityCoalesceTransformer': + """ + Train the estimator to identify low frequency classes on the input train data. + + Args: + X (Union[np.ndarray, sparse.csr_matrix]): + The input features from the user, likely transformed by an encoder and imputer. + y (Optional[np.ndarray]): + Optional labels for the given task, not used by this estimator.
+ """ + self._check_dataset(X) + n_instances, n_features = X.shape + + if self.min_frac is None: + self._categories_to_coalesce = [np.array([]) for _ in range(n_features)] + return self + + categories_to_coalesce: List[np.ndarray] = [] + is_sparse = sparse.issparse(X) + for col in range(n_features): + col_data = self._get_column_data(X=X, col_idx=col, is_sparse=is_sparse) + unique_vals, counts = np.unique(col_data, return_counts=True) + frac = counts / n_instances + categories_to_coalesce.append(unique_vals[frac < self.min_frac]) + + self._categories_to_coalesce = categories_to_coalesce + return self + + def transform( + self, + X: Union[np.ndarray, sparse.csr_matrix] + ) -> Union[np.ndarray, sparse.csr_matrix]: + """ + Coalesce categories with low frequency in X. + + Args: + X (Union[np.ndarray, sparse.csr_matrix]): + The input features from the user, likely transformed by an encoder and imputer. + """ + self._check_dataset(X) + + if self._categories_to_coalesce is None: + raise RuntimeError("fit() must be called before transform()") + + if self.min_frac is None: + return X + + n_features = X.shape[1] + is_sparse = sparse.issparse(X) + + for col in range(n_features): + # -2 stands for coalesced.
For more details, see the doc in _check_dataset + col_data = self._get_column_data(X=X, col_idx=col, is_sparse=is_sparse) + mask = np.isin(col_data, self._categories_to_coalesce[col]) + col_data[mask] = -2 + + return X + + def fit_transform(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray: + return self.fit(X, y).transform(X) diff --git a/test/test_api/.tmp_api/runhistory.json b/test/test_api/.tmp_api/runhistory.json index 6f61e1395..28c0cbd32 100644 --- a/test/test_api/.tmp_api/runhistory.json +++ b/test/test_api/.tmp_api/runhistory.json @@ -705,6 +705,7 @@ "1": { "data_loader:batch_size": 64, "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "NoFeaturePreprocessor", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "ReduceLROnPlateau", @@ -737,6 +738,7 @@ "2": { "data_loader:batch_size": 101, "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "PowerTransformer", "imputer:numerical_strategy": "most_frequent", "lr_scheduler:__choice__": "CyclicLR", @@ -801,6 +803,7 @@ "3": { "data_loader:batch_size": 242, "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "RandomKitchenSinks", "imputer:numerical_strategy": "median", "lr_scheduler:__choice__": "NoScheduler", @@ -831,6 +834,7 @@ "4": { "data_loader:batch_size": 115, "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "Nystroem", "imputer:numerical_strategy": "median", "lr_scheduler:__choice__": "CosineAnnealingLR", @@ -864,6 +868,7 @@ "5": { "data_loader:batch_size": 185, "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "RandomKitchenSinks", "imputer:numerical_strategy": "median", "lr_scheduler:__choice__": "ReduceLROnPlateau", @@ -904,6 +909,7 @@ "6": { "data_loader:batch_size": 95, 
"encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "RandomKitchenSinks", "imputer:numerical_strategy": "most_frequent", "lr_scheduler:__choice__": "ExponentialLR", @@ -937,6 +943,7 @@ "7": { "data_loader:batch_size": 119, "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "Nystroem", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "StepLR", @@ -979,6 +986,7 @@ "8": { "data_loader:batch_size": 130, "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "PolynomialFeatures", "imputer:numerical_strategy": "median", "lr_scheduler:__choice__": "CyclicLR", @@ -1032,6 +1040,7 @@ "9": { "data_loader:batch_size": 137, "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "Nystroem", "imputer:numerical_strategy": "mean", "lr_scheduler:__choice__": "CosineAnnealingLR", diff --git a/test/test_pipeline/components/preprocessing/base.py b/test/test_pipeline/components/preprocessing/base.py index 35f6ed271..a2705e19b 100644 --- a/test/test_pipeline/components/preprocessing/base.py +++ b/test/test_pipeline/components/preprocessing/base.py @@ -3,6 +3,7 @@ from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.TabularColumnTransformer import \ TabularColumnTransformer +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.coalescer import CoalescerChoice from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding import EncoderChoice from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.imputation.SimpleImputer import SimpleImputer from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling import ScalerChoice @@ -31,6 +32,7 @@ def _get_pipeline_steps(self, 
dataset_properties: Optional[Dict[str, Any]], steps.extend([ ("imputer", SimpleImputer()), ("variance_threshold", VarianceThreshold()), + ("coalescer", CoalescerChoice(default_dataset_properties)), ("encoder", EncoderChoice(default_dataset_properties)), ("scaler", ScalerChoice(default_dataset_properties)), ("tabular_transformer", TabularColumnTransformer()), diff --git a/test/test_pipeline/components/preprocessing/test_coalescer.py b/test/test_pipeline/components/preprocessing/test_coalescer.py new file mode 100644 index 000000000..811cf8b6e --- /dev/null +++ b/test/test_pipeline/components/preprocessing/test_coalescer.py @@ -0,0 +1,86 @@ +import copy +import unittest + +import numpy as np + +import pytest + +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.coalescer import ( + CoalescerChoice +) +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.coalescer.MinorityCoalescer import ( + MinorityCoalescer +) + + +def test_transform_before_fit(): + with pytest.raises(RuntimeError): + mc = MinorityCoalescer(min_frac=None, random_state=np.random.RandomState()) + mc.transform(np.random.random((4, 4))) + + +class TestCoalescerChoice(unittest.TestCase): + def test_raise_error_in_check_update_compatiblity(self): + dataset_properties = {'numerical_columns': [], 'categorical_columns': []} + cc = CoalescerChoice(dataset_properties) + choices = ["NoCoescer"] # component name with typo + with pytest.raises(ValueError): + # raise error because no categorical columns, but choices do not have no coalescer + cc._check_update_compatiblity(choices_in_update=choices, dataset_properties=dataset_properties) + + def test_raise_error_in_get_component_without_updates(self): + dataset_properties = {'numerical_columns': [], 'categorical_columns': []} + cc = CoalescerChoice(dataset_properties) + with pytest.raises(ValueError): + # raise error because no categorical columns, but choices do not have no coalescer + cc._get_component_without_updates( + 
avail_components={}, + dataset_properties=dataset_properties, + default="", + include=[] + ) + + def test_get_set_config_space(self): + """Make sure that we can setup a valid choice in the Coalescer + choice""" + dataset_properties = {'numerical_columns': list(range(4)), 'categorical_columns': [5]} + coalescer_choice = CoalescerChoice(dataset_properties) + cs = coalescer_choice.get_hyperparameter_search_space() + + # Make sure that all hyperparameters are part of the search space + self.assertListEqual( + sorted(cs.get_hyperparameter('__choice__').choices), + sorted(list(coalescer_choice.get_components().keys())) + ) + + # Make sure we can properly set some random configs + # Whereas just one iteration will make sure the algorithm works, + # doing five iterations increases the confidence. We will be able to + # catch component specific crashes + for _ in range(5): + config = cs.sample_configuration() + config_dict = copy.deepcopy(config.get_dictionary()) + coalescer_choice.set_hyperparameters(config) + + self.assertEqual(coalescer_choice.choice.__class__, + coalescer_choice.get_components()[config_dict['__choice__']]) + + # Then check the choice configuration + selected_choice = config_dict.pop('__choice__', None) + for key, value in config_dict.items(): + # Remove the selected_choice string from the parameter + # so we can query in the object for it + key = key.replace(selected_choice + ':', '') + self.assertIn(key, vars(coalescer_choice.choice)) + self.assertEqual(value, coalescer_choice.choice.__dict__[key]) + + def test_only_numerical(self): + dataset_properties = {'numerical_columns': list(range(4)), 'categorical_columns': []} + + chooser = CoalescerChoice(dataset_properties) + configspace = chooser.get_hyperparameter_search_space().sample_configuration().get_dictionary() + self.assertEqual(configspace['__choice__'], 'NoCoalescer') + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_utils/runhistory.json b/test/test_utils/runhistory.json 
index 37e499664..a2c3658a8 100755 --- a/test/test_utils/runhistory.json +++ b/test/test_utils/runhistory.json @@ -1133,6 +1133,7 @@ "1": { "data_loader:batch_size": 64, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "NoFeaturePreprocessor", "imputer:categorical_strategy": "most_frequent", "imputer:numerical_strategy": "mean", @@ -1166,6 +1167,7 @@ "2": { "data_loader:batch_size": 142, "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "PowerTransformer", "imputer:categorical_strategy": "constant_!missing!", "imputer:numerical_strategy": "median", @@ -1203,6 +1205,7 @@ "3": { "data_loader:batch_size": 246, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "PowerTransformer", "imputer:categorical_strategy": "constant_!missing!", "imputer:numerical_strategy": "most_frequent", @@ -1281,6 +1284,7 @@ "4": { "data_loader:batch_size": 269, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "PowerTransformer", "imputer:categorical_strategy": "constant_!missing!", "imputer:numerical_strategy": "median", @@ -1324,6 +1328,7 @@ "5": { "data_loader:batch_size": 191, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "RandomKitchenSinks", "imputer:categorical_strategy": "constant_!missing!", "imputer:numerical_strategy": "most_frequent", @@ -1373,6 +1378,7 @@ "6": { "data_loader:batch_size": 53, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "PowerTransformer", "imputer:categorical_strategy": "constant_!missing!", "imputer:numerical_strategy": "median", @@ -1429,6 +1435,7 @@ "7": { "data_loader:batch_size": 232, "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", 
"feature_preprocessor:__choice__": "RandomKitchenSinks", "imputer:categorical_strategy": "most_frequent", "imputer:numerical_strategy": "most_frequent", @@ -1506,6 +1513,7 @@ "8": { "data_loader:batch_size": 164, "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "NoFeaturePreprocessor", "imputer:categorical_strategy": "most_frequent", "imputer:numerical_strategy": "mean", @@ -1540,6 +1548,7 @@ "9": { "data_loader:batch_size": 94, "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "PolynomialFeatures", "imputer:categorical_strategy": "most_frequent", "imputer:numerical_strategy": "mean", @@ -1589,6 +1598,7 @@ "10": { "data_loader:batch_size": 70, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "PowerTransformer", "imputer:categorical_strategy": "most_frequent", "imputer:numerical_strategy": "constant_zero", @@ -1637,6 +1647,7 @@ "11": { "data_loader:batch_size": 274, "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "RandomKitchenSinks", "imputer:categorical_strategy": "constant_!missing!", "imputer:numerical_strategy": "mean", @@ -1675,6 +1686,7 @@ "12": { "data_loader:batch_size": 191, "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "NoFeaturePreprocessor", "imputer:categorical_strategy": "constant_!missing!", "imputer:numerical_strategy": "median", @@ -1730,6 +1742,7 @@ "13": { "data_loader:batch_size": 35, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", "feature_preprocessor:__choice__": "PowerTransformer", "imputer:categorical_strategy": "most_frequent", "imputer:numerical_strategy": "most_frequent", @@ -1766,6 +1779,7 @@ "14": { "data_loader:batch_size": 154, "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": 
"NoCoalescer", "feature_preprocessor:__choice__": "KernelPCA", "imputer:categorical_strategy": "most_frequent", "imputer:numerical_strategy": "mean", diff --git a/test/test_utils/test_coalescer_transformer.py b/test/test_utils/test_coalescer_transformer.py new file mode 100644 index 000000000..eccd6b7bd --- /dev/null +++ b/test/test_utils/test_coalescer_transformer.py @@ -0,0 +1,101 @@ +import numpy as np + +import pytest + +import scipy.sparse + +from autoPyTorch.utils.implementations import MinorityCoalesceTransformer + + +@pytest.fixture +def X1(): + # Generates an array with categories 3, 4, 5, 6, 7 and occurences of 30%, + # 30%, 30%, 5% and 5% respectively + X = np.vstack(( + np.ones((30, 10)) * 3, + np.ones((30, 10)) * 4, + np.ones((30, 10)) * 5, + np.ones((5, 10)) * 6, + np.ones((5, 10)) * 7, + )) + for col in range(X.shape[1]): + np.random.shuffle(X[:, col]) + return X + + +@pytest.fixture +def X2(): + # Generates an array with categories 3, 4, 5, 6, 7 and occurences of 5%, + # 5%, 5%, 35% and 50% respectively + X = np.vstack(( + np.ones((5, 10)) * 3, + np.ones((5, 10)) * 4, + np.ones((5, 10)) * 5, + np.ones((35, 10)) * 6, + np.ones((50, 10)) * 7, + )) + for col in range(X.shape[1]): + np.random.shuffle(X[:, col]) + return X + + +def test_default(X1): + X = X1 + X_copy = np.copy(X) + Y = MinorityCoalesceTransformer().fit_transform(X) + np.testing.assert_array_almost_equal(Y, X_copy) + # Assert no copies were made + assert id(X) == id(Y) + + +def test_coalesce_10_percent(X1): + X = X1 + Y = MinorityCoalesceTransformer(min_frac=.1).fit_transform(X) + for col in range(Y.shape[1]): + hist = np.histogram(Y[:, col], bins=np.arange(-2, 7)) + np.testing.assert_array_almost_equal(hist[0], [10, 0, 0, 0, 0, 30, 30, 30]) + # Assert no copies were made + assert id(X) == id(Y) + + +def test_coalesce_10_percent_sparse(X1): + X = scipy.sparse.csc_matrix(X1) + Y = MinorityCoalesceTransformer(min_frac=.1).fit_transform(X) + # Assert no copies were made + assert id(X) == 
id(Y) + Y = Y.todense() + for col in range(Y.shape[1]): + hist = np.histogram(Y[:, col], bins=np.arange(-2, 7)) + np.testing.assert_array_almost_equal(hist[0], [10, 0, 0, 0, 0, 30, 30, 30]) + + +def test_invalid_X(X1): + X = X1 - 5 + with pytest.raises(ValueError): + MinorityCoalesceTransformer().fit_transform(X) + + +@pytest.mark.parametrize("min_frac", [-0.1, 1.1]) +def test_invalid_min_frac(min_frac): + with pytest.raises(ValueError): + MinorityCoalesceTransformer(min_frac=min_frac) + + +def test_transform_before_fit(X1): + with pytest.raises(RuntimeError): + MinorityCoalesceTransformer().transform(X1) + + +def test_transform_after_fit(X1, X2): + # On both X_fit and X_transf, the categories 3, 4, 5, 6, 7 are present. + X_fit = X1 # Here categories 3, 4, 5 have occurrence above 10% + X_transf = X2 # Here it is the opposite, just categories 6 and 7 are above 10% + + mc = MinorityCoalesceTransformer(min_frac=.1).fit(X_fit) + + # transform() should coalesce categories as learned during fit. + # Category distribution in X_transf should be irrelevant. + Y = mc.transform(X_transf) + for col in range(Y.shape[1]): + hist = np.histogram(Y[:, col], bins=np.arange(-2, 7)) + np.testing.assert_array_almost_equal(hist[0], [85, 0, 0, 0, 0, 5, 5, 5])