From 09fdc0dcd9c0d9227bfd2d43a638d61f08e3e3d3 Mon Sep 17 00:00:00 2001 From: Ravin Kohli <13005107+ravinkohli@users.noreply.github.com> Date: Sat, 16 Jul 2022 17:31:36 +0200 Subject: [PATCH] [ADD] Calculate memory of dataset after one hot encoding (pytorch embedding) (#437) * add updates for apt1.0+reg_cocktails * debug loggers for checking data and network memory usage * add support for pandas, test for data passing, remove debug loggers * remove unwanted changes * : * Adjust formula to account for embedding columns * Apply suggestions from code review Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com> * remove unwanted additions * Update autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com> --- autoPyTorch/constants.py | 2 + autoPyTorch/data/tabular_validator.py | 2 + autoPyTorch/data/utils.py | 53 ++++++++++++++++--- .../TabularColumnTransformer.py | 1 + .../column_splitting/ColumnSplitter.py | 4 +- .../encoding/OneHotEncoder.py | 3 +- test/test_api/test_api.py | 4 +- test/test_data/test_utils.py | 21 ++++++-- test/test_data/test_validation.py | 29 ++++++++-- 9 files changed, 98 insertions(+), 21 deletions(-) diff --git a/autoPyTorch/constants.py b/autoPyTorch/constants.py index bfd56d27f..154d562ac 100644 --- a/autoPyTorch/constants.py +++ b/autoPyTorch/constants.py @@ -78,3 +78,5 @@ # To avoid that we get a sequence that is too long to be fed to a network MAX_WINDOW_SIZE_BASE = 500 + +MIN_CATEGORIES_FOR_EMBEDDING_MAX = 7 diff --git a/autoPyTorch/data/tabular_validator.py b/autoPyTorch/data/tabular_validator.py index 0f6f89e1c..0735d49b4 100644 --- a/autoPyTorch/data/tabular_validator.py +++ b/autoPyTorch/data/tabular_validator.py @@ -111,6 +111,8 @@ def _compress_dataset( y=y, is_classification=self.is_classification, random_state=self.seed, + categorical_columns=self.feature_validator.categorical_columns, + n_categories_per_cat_column=self.feature_validator.num_categories_per_col, **self.dataset_compression # type: ignore [arg-type] ) self._reduced_dtype = dict(X.dtypes) if is_dataframe else X.dtype diff --git a/autoPyTorch/data/utils.py b/autoPyTorch/data/utils.py index 20ad5612e..2a44dd5c2 100644 --- a/autoPyTorch/data/utils.py +++ b/autoPyTorch/data/utils.py @@ -25,6 +25,7 @@ from sklearn.utils import _approximate_mode, check_random_state from sklearn.utils.validation import _num_samples, check_array +from autoPyTorch.constants import MIN_CATEGORIES_FOR_EMBEDDING_MAX from autoPyTorch.data.base_target_validator import SupportedTargetTypes from autoPyTorch.utils.common import ispandas @@ -459,8 +460,8 @@ def _subsample_by_indices( return X, y -def megabytes(arr: DatasetCompressionInputType) -> float: - +def get_raw_memory_usage(arr: DatasetCompressionInputType) -> float: + memory_in_bytes: float if isinstance(arr, np.ndarray): memory_in_bytes = arr.nbytes elif issparse(arr): @@ -470,8 +471,43 @@ def megabytes(arr: DatasetCompressionInputType) -> float: else: raise ValueError(f"Unrecognised data type of X, expected data type to " f"be in (np.ndarray, spmatrix, pd.DataFrame) but got :{type(arr)}") + return memory_in_bytes + + +def get_approximate_mem_usage_in_mb( + arr: DatasetCompressionInputType, + categorical_columns: List, + n_categories_per_cat_column: Optional[List[int]] = None +) -> float: + + err_msg = "Value number of categories per categorical is required when the data has categorical columns" + if ispandas(arr): + arr_dtypes = arr.dtypes.to_dict() + multipliers = [dtype.itemsize for col, dtype in arr_dtypes.items() if col not in categorical_columns] + if len(categorical_columns) > 0: + if n_categories_per_cat_column is None: + raise ValueError(err_msg) + for col, num_cat in zip(categorical_columns, n_categories_per_cat_column): + if num_cat < MIN_CATEGORIES_FOR_EMBEDDING_MAX: + multipliers.append(num_cat * arr_dtypes[col].itemsize) + else: + multipliers.append(arr_dtypes[col].itemsize) + size_one_row = sum(multipliers) + + elif isinstance(arr, (np.ndarray, spmatrix)): + n_cols = arr.shape[-1] - len(categorical_columns) + multiplier = arr.dtype.itemsize + if len(categorical_columns) > 0: + if n_categories_per_cat_column is None: + raise ValueError(err_msg) + # multiply num categories with the size of the column to capture memory after one hot encoding + n_cols += sum(num_cat if num_cat < MIN_CATEGORIES_FOR_EMBEDDING_MAX else 1 for num_cat in n_categories_per_cat_column) + size_one_row = n_cols * multiplier + else: + raise ValueError(f"Unrecognised data type of X, expected data type to " + f"be in (np.ndarray, spmatrix, pd.DataFrame), but got :{type(arr)}") - return float(memory_in_bytes / (2**20)) + return float(arr.shape[0] * size_one_row / (2**20)) def reduce_dataset_size_if_too_large( @@ -479,10 +515,13 @@ def reduce_dataset_size_if_too_large( memory_allocation: Union[int, float], is_classification: bool, random_state: Union[int, np.random.RandomState], + categorical_columns: List, + n_categories_per_cat_column: Optional[List[int]] = None, y: Optional[SupportedTargetTypes] = None, methods: List[str] = ['precision', 'subsample'], ) -> DatasetCompressionInputType: - f""" Reduces the size of the dataset if it's too close to the memory limit. + f""" + Reduces the size of the dataset if it's too close to the memory limit. Follows the order of the operations passed in and retains the type of its input. @@ -513,7 +552,6 @@ def reduce_dataset_size_if_too_large( Reduce the amount of samples of the dataset such that it fits into the allocated memory. Ensures stratification and that unique labels are present - memory_allocation (Union[int, float]): The amount of memory to allocate to the dataset. It should specify an absolute amount. @@ -524,7 +562,7 @@ def reduce_dataset_size_if_too_large( """ for method in methods: - if megabytes(X) <= memory_allocation: + if get_approximate_mem_usage_in_mb(X, categorical_columns, n_categories_per_cat_column) <= memory_allocation: break if method == 'precision': @@ -540,7 +578,8 @@ def reduce_dataset_size_if_too_large( # into the allocated memory, we subsample it so that it does n_samples_before = X.shape[0] - sample_percentage = memory_allocation / megabytes(X) + sample_percentage = memory_allocation / get_approximate_mem_usage_in_mb( + X, categorical_columns, n_categories_per_cat_column) # NOTE: type ignore # diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py index 6b38b4650..48f40e9fe 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py @@ -24,6 +24,7 @@ def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = N self.add_fit_requirements([ FitRequirement('numerical_columns', (List,), user_defined=True, dataset_property=True), FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True)]) + def get_column_transformer(self) -> ColumnTransformer: """ diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/column_splitting/ColumnSplitter.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/column_splitting/ColumnSplitter.py index eeca9fdc4..437198d9e 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/column_splitting/ColumnSplitter.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/column_splitting/ColumnSplitter.py @@ -7,7 +7,7 @@ import numpy as np - +from autoPyTorch.constants import MIN_CATEGORIES_FOR_EMBEDDING_MAX from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.base_tabular_preprocessing import \ autoPyTorchTabularPreprocessingComponent @@ -72,7 +72,7 @@ def get_hyperparameter_search_space( dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None, min_categories_for_embedding: HyperparameterSearchSpace = HyperparameterSearchSpace( hyperparameter="min_categories_for_embedding", - value_range=(3, 7), + value_range=(3, MIN_CATEGORIES_FOR_EMBEDDING_MAX), default_value=3, log=True), ) -> ConfigurationSpace: diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/OneHotEncoder.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/OneHotEncoder.py index 80cf3f748..4f8878615 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/OneHotEncoder.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/OneHotEncoder.py @@ -24,7 +24,8 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEncoder: # It is safer to have the OHE produce a 0 array than to crash a good configuration categories='auto', sparse=False, - handle_unknown='ignore') + handle_unknown='ignore', + dtype=np.float32) return self @staticmethod diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py index 3e8847110..f71ad3f5f 100644 --- a/test/test_api/test_api.py +++ b/test/test_api/test_api.py @@ -682,10 +682,10 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular): del estimator -@pytest.skip("Fix with new portfolio PR") @unittest.mock.patch('autoPyTorch.evaluation.tae.eval_train_function', new=dummy_eval_train_function) @pytest.mark.parametrize('openml_id', (40981, )) +@pytest.mark.skip(reason="Fix with new portfolio PR") def test_portfolio_selection(openml_id, backend, n_samples): # Get the data and check that contents of data-manager make sense @@ -725,7 +725,7 @@ def test_portfolio_selection(openml_id, backend, n_samples): assert any(successful_config in portfolio_configs for successful_config in successful_configs) -@pytest.skip("Fix with new portfolio PR") +@pytest.mark.skip(reason="Fix with new portfolio PR") @unittest.mock.patch('autoPyTorch.evaluation.tae.eval_train_function', new=dummy_eval_train_function) @pytest.mark.parametrize('openml_id', (40981, )) diff --git a/test/test_data/test_utils.py b/test/test_data/test_utils.py index 4269c4e5f..6228740b0 100644 --- a/test/test_data/test_utils.py +++ b/test/test_data/test_utils.py @@ -25,7 +25,7 @@ from autoPyTorch.data.utils import ( default_dataset_compression_arg, get_dataset_compression_mapping, - megabytes, + get_raw_memory_usage, reduce_dataset_size_if_too_large, reduce_precision, subsample, @@ -45,13 +45,14 @@ def test_reduce_dataset_if_too_large(openmlid, as_frame, n_samples): X.copy(), y=y.copy(), is_classification=True, + categorical_columns=[], random_state=1, - memory_allocation=0.001) + memory_allocation=0.01) assert X_converted.shape[0] < X.shape[0] assert y_converted.shape[0] < y.shape[0] - assert megabytes(X_converted) < megabytes(X) + assert get_raw_memory_usage(X_converted) < get_raw_memory_usage(X) @pytest.mark.parametrize("X", [np.asarray([[1, 1, 1]] * 30)]) @@ -211,8 +212,18 @@ def test_unsupported_errors(): ['a', 'b', 'c', 'a', 'b', 'c'], ['a', 'b', 'd', 'r', 'b', 'c']]) with pytest.raises(ValueError, match=r'X.dtype = .*'): - reduce_dataset_size_if_too_large(X, is_classification=True, random_state=1, memory_allocation=0) + reduce_dataset_size_if_too_large( + X, + is_classification=True, + categorical_columns=[], + random_state=1, + memory_allocation=0) X = [[1, 2], [2, 3]] with pytest.raises(ValueError, match=r'Unrecognised data type of X, expected data type to be in .*'): - reduce_dataset_size_if_too_large(X, is_classification=True, random_state=1, memory_allocation=0) + reduce_dataset_size_if_too_large( + X, + is_classification=True, + categorical_columns=[], + random_state=1, + memory_allocation=0) diff --git a/test/test_data/test_validation.py b/test/test_data/test_validation.py index af46be55f..b6f05f7ba 100644 --- a/test/test_data/test_validation.py +++ b/test/test_data/test_validation.py @@ -8,7 +8,8 @@ import sklearn.model_selection from autoPyTorch.data.tabular_validator import TabularInputValidator -from autoPyTorch.data.utils import megabytes +from autoPyTorch.data.utils import get_approximate_mem_usage_in_mb +from autoPyTorch.utils.common import ispandas @pytest.mark.parametrize('openmlid', [2, 40975, 40984]) @@ -148,16 +149,36 @@ def test_featurevalidator_dataset_compression(input_data_featuretest): X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( input_data_featuretest, input_data_targets, test_size=0.1, random_state=1) validator = TabularInputValidator( - dataset_compression={'memory_allocation': 0.8 * megabytes(X_train), 'methods': ['precision', 'subsample']} + dataset_compression={ + 'memory_allocation': 0.8 * get_approximate_mem_usage_in_mb(X_train, [], None), + 'methods': ['precision', 'subsample']} ) validator.fit(X_train=X_train, y_train=y_train) transformed_X_train, _ = validator.transform(X_train.copy(), y_train.copy()) + if ispandas(X_train): + # input validator converts transformed_X_train to numpy and the cat columns are chosen as column indices + columns = X_train.columns + categorical_columns = [columns[col] for col in validator.feature_validator.categorical_columns] + else: + categorical_columns = validator.feature_validator.categorical_columns + assert validator._reduced_dtype is not None - assert megabytes(transformed_X_train) < megabytes(X_train) + assert get_approximate_mem_usage_in_mb( + transformed_X_train, + validator.feature_validator.categorical_columns, + validator.feature_validator.num_categories_per_col + ) < get_approximate_mem_usage_in_mb( + X_train, categorical_columns, validator.feature_validator.num_categories_per_col) transformed_X_test, _ = validator.transform(X_test.copy(), y_test.copy()) - assert megabytes(transformed_X_test) < megabytes(X_test) + assert get_approximate_mem_usage_in_mb( + transformed_X_test, + validator.feature_validator.categorical_columns, + validator.feature_validator.num_categories_per_col + ) < get_approximate_mem_usage_in_mb( + X_test, categorical_columns, validator.feature_validator.num_categories_per_col) + if hasattr(transformed_X_train, 'iloc'): assert all(transformed_X_train.dtypes == transformed_X_test.dtypes) assert all(transformed_X_train.dtypes == validator._precision)