Skip to content

Commit

Permalink
[ADD] Calculate memory of dataset after one hot encoding (pytorch emb…
Browse files Browse the repository at this point in the history
…edding) (#437)

* add updates for apt1.0+reg_cocktails

* debug loggers for checking data and network memory usage

* add support for pandas, test for data passing, remove debug loggers

* remove unwanted changes

* :

* Adjust formula to account for embedding columns

* Apply suggestions from code review

Co-authored-by: nabenabe0928 <[email protected]>

* remove unwanted additions

* Update autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py

Co-authored-by: nabenabe0928 <[email protected]>
  • Loading branch information
ravinkohli and nabenabe0928 committed Oct 25, 2022
1 parent 7567d26 commit 09fdc0d
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 21 deletions.
2 changes: 2 additions & 0 deletions autoPyTorch/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,5 @@

# To avoid that we get a sequence that is too long to be fed to a network
MAX_WINDOW_SIZE_BASE = 500

MIN_CATEGORIES_FOR_EMBEDDING_MAX = 7
2 changes: 2 additions & 0 deletions autoPyTorch/data/tabular_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def _compress_dataset(
y=y,
is_classification=self.is_classification,
random_state=self.seed,
categorical_columns=self.feature_validator.categorical_columns,
n_categories_per_cat_column=self.feature_validator.num_categories_per_col,
**self.dataset_compression # type: ignore [arg-type]
)
self._reduced_dtype = dict(X.dtypes) if is_dataframe else X.dtype
Expand Down
53 changes: 46 additions & 7 deletions autoPyTorch/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sklearn.utils import _approximate_mode, check_random_state
from sklearn.utils.validation import _num_samples, check_array

from autoPyTorch.constants import MIN_CATEGORIES_FOR_EMBEDDING_MAX
from autoPyTorch.data.base_target_validator import SupportedTargetTypes
from autoPyTorch.utils.common import ispandas

Expand Down Expand Up @@ -459,8 +460,8 @@ def _subsample_by_indices(
return X, y


def megabytes(arr: DatasetCompressionInputType) -> float:

def get_raw_memory_usage(arr: DatasetCompressionInputType) -> float:
memory_in_bytes: float
if isinstance(arr, np.ndarray):
memory_in_bytes = arr.nbytes
elif issparse(arr):
Expand All @@ -470,19 +471,57 @@ def megabytes(arr: DatasetCompressionInputType) -> float:
else:
raise ValueError(f"Unrecognised data type of X, expected data type to "
f"be in (np.ndarray, spmatrix, pd.DataFrame) but got :{type(arr)}")
return memory_in_bytes


def get_approximate_mem_usage_in_mb(
arr: DatasetCompressionInputType,
categorical_columns: List,
n_categories_per_cat_column: Optional[List[int]] = None
) -> float:

err_msg = "Value number of categories per categorical is required when the data has categorical columns"
if ispandas(arr):
arr_dtypes = arr.dtypes.to_dict()
multipliers = [dtype.itemsize for col, dtype in arr_dtypes.items() if col not in categorical_columns]
if len(categorical_columns) > 0:
if n_categories_per_cat_column is None:
raise ValueError(err_msg)
for col, num_cat in zip(categorical_columns, n_categories_per_cat_column):
if num_cat < MIN_CATEGORIES_FOR_EMBEDDING_MAX:
multipliers.append(num_cat * arr_dtypes[col].itemsize)
else:
multipliers.append(arr_dtypes[col].itemsize)
size_one_row = sum(multipliers)

elif isinstance(arr, (np.ndarray, spmatrix)):
n_cols = arr.shape[-1] - len(categorical_columns)
multiplier = arr.dtype.itemsize
if len(categorical_columns) > 0:
if n_categories_per_cat_column is None:
raise ValueError(err_msg)
# multiply num categories with the size of the column to capture memory after one hot encoding
n_cols += sum(num_cat if num_cat < MIN_CATEGORIES_FOR_EMBEDDING_MAX else 1 for num_cat in n_categories_per_cat_column)
size_one_row = n_cols * multiplier
else:
raise ValueError(f"Unrecognised data type of X, expected data type to "
f"be in (np.ndarray, spmatrix, pd.DataFrame), but got :{type(arr)}")

return float(memory_in_bytes / (2**20))
return float(arr.shape[0] * size_one_row / (2**20))


def reduce_dataset_size_if_too_large(
X: DatasetCompressionInputType,
memory_allocation: Union[int, float],
is_classification: bool,
random_state: Union[int, np.random.RandomState],
categorical_columns: List,
n_categories_per_cat_column: Optional[List[int]] = None,
y: Optional[SupportedTargetTypes] = None,
methods: List[str] = ['precision', 'subsample'],
) -> DatasetCompressionInputType:
f""" Reduces the size of the dataset if it's too close to the memory limit.
f"""
Reduces the size of the dataset if it's too close to the memory limit.
Follows the order of the operations passed in and retains the type of its
input.
Expand Down Expand Up @@ -513,7 +552,6 @@ def reduce_dataset_size_if_too_large(
Reduce the amount of samples of the dataset such that it fits into the allocated
memory. Ensures stratification and that unique labels are present
memory_allocation (Union[int, float]):
The amount of memory to allocate to the dataset. It should specify an
absolute amount.
Expand All @@ -524,7 +562,7 @@ def reduce_dataset_size_if_too_large(
"""

for method in methods:
if megabytes(X) <= memory_allocation:
if get_approximate_mem_usage_in_mb(X, categorical_columns, n_categories_per_cat_column) <= memory_allocation:
break

if method == 'precision':
Expand All @@ -540,7 +578,8 @@ def reduce_dataset_size_if_too_large(
# into the allocated memory, we subsample it so that it does

n_samples_before = X.shape[0]
sample_percentage = memory_allocation / megabytes(X)
sample_percentage = memory_allocation / get_approximate_mem_usage_in_mb(
X, categorical_columns, n_categories_per_cat_column)

# NOTE: type ignore
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = N
self.add_fit_requirements([
FitRequirement('numerical_columns', (List,), user_defined=True, dataset_property=True),
FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True)])


def get_column_transformer(self) -> ColumnTransformer:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np


from autoPyTorch.constants import MIN_CATEGORIES_FOR_EMBEDDING_MAX
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.base_tabular_preprocessing import \
autoPyTorchTabularPreprocessingComponent
Expand Down Expand Up @@ -72,7 +72,7 @@ def get_hyperparameter_search_space(
dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None,
min_categories_for_embedding: HyperparameterSearchSpace = HyperparameterSearchSpace(
hyperparameter="min_categories_for_embedding",
value_range=(3, 7),
value_range=(3, MIN_CATEGORIES_FOR_EMBEDDING_MAX),
default_value=3,
log=True),
) -> ConfigurationSpace:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEncoder:
# It is safer to have the OHE produce a 0 array than to crash a good configuration
categories='auto',
sparse=False,
handle_unknown='ignore')
handle_unknown='ignore',
dtype=np.float32)
return self

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions test/test_api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,10 +682,10 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
del estimator


@pytest.skip("Fix with new portfolio PR")
@unittest.mock.patch('autoPyTorch.evaluation.tae.eval_train_function',
new=dummy_eval_train_function)
@pytest.mark.parametrize('openml_id', (40981, ))
@pytest.mark.skip(reason="Fix with new portfolio PR")
def test_portfolio_selection(openml_id, backend, n_samples):

# Get the data and check that contents of data-manager make sense
Expand Down Expand Up @@ -725,7 +725,7 @@ def test_portfolio_selection(openml_id, backend, n_samples):
assert any(successful_config in portfolio_configs for successful_config in successful_configs)


@pytest.skip("Fix with new portfolio PR")
@pytest.mark.skip(reason="Fix with new portfolio PR")
@unittest.mock.patch('autoPyTorch.evaluation.tae.eval_train_function',
new=dummy_eval_train_function)
@pytest.mark.parametrize('openml_id', (40981, ))
Expand Down
21 changes: 16 additions & 5 deletions test/test_data/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from autoPyTorch.data.utils import (
default_dataset_compression_arg,
get_dataset_compression_mapping,
megabytes,
get_raw_memory_usage,
reduce_dataset_size_if_too_large,
reduce_precision,
subsample,
Expand All @@ -45,13 +45,14 @@ def test_reduce_dataset_if_too_large(openmlid, as_frame, n_samples):
X.copy(),
y=y.copy(),
is_classification=True,
categorical_columns=[],
random_state=1,
memory_allocation=0.001)
memory_allocation=0.01)

assert X_converted.shape[0] < X.shape[0]
assert y_converted.shape[0] < y.shape[0]

assert megabytes(X_converted) < megabytes(X)
assert get_raw_memory_usage(X_converted) < get_raw_memory_usage(X)


@pytest.mark.parametrize("X", [np.asarray([[1, 1, 1]] * 30)])
Expand Down Expand Up @@ -211,8 +212,18 @@ def test_unsupported_errors():
['a', 'b', 'c', 'a', 'b', 'c'],
['a', 'b', 'd', 'r', 'b', 'c']])
with pytest.raises(ValueError, match=r'X.dtype = .*'):
reduce_dataset_size_if_too_large(X, is_classification=True, random_state=1, memory_allocation=0)
reduce_dataset_size_if_too_large(
X,
is_classification=True,
categorical_columns=[],
random_state=1,
memory_allocation=0)

X = [[1, 2], [2, 3]]
with pytest.raises(ValueError, match=r'Unrecognised data type of X, expected data type to be in .*'):
reduce_dataset_size_if_too_large(X, is_classification=True, random_state=1, memory_allocation=0)
reduce_dataset_size_if_too_large(
X,
is_classification=True,
categorical_columns=[],
random_state=1,
memory_allocation=0)
29 changes: 25 additions & 4 deletions test/test_data/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import sklearn.model_selection

from autoPyTorch.data.tabular_validator import TabularInputValidator
from autoPyTorch.data.utils import megabytes
from autoPyTorch.data.utils import get_approximate_mem_usage_in_mb
from autoPyTorch.utils.common import ispandas


@pytest.mark.parametrize('openmlid', [2, 40975, 40984])
Expand Down Expand Up @@ -148,16 +149,36 @@ def test_featurevalidator_dataset_compression(input_data_featuretest):
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
input_data_featuretest, input_data_targets, test_size=0.1, random_state=1)
validator = TabularInputValidator(
dataset_compression={'memory_allocation': 0.8 * megabytes(X_train), 'methods': ['precision', 'subsample']}
dataset_compression={
'memory_allocation': 0.8 * get_approximate_mem_usage_in_mb(X_train, [], None),
'methods': ['precision', 'subsample']}
)
validator.fit(X_train=X_train, y_train=y_train)
transformed_X_train, _ = validator.transform(X_train.copy(), y_train.copy())

if ispandas(X_train):
# input validator converts transformed_X_train to numpy and the cat columns are chosen as column indices
columns = X_train.columns
categorical_columns = [columns[col] for col in validator.feature_validator.categorical_columns]
else:
categorical_columns = validator.feature_validator.categorical_columns

assert validator._reduced_dtype is not None
assert megabytes(transformed_X_train) < megabytes(X_train)
assert get_approximate_mem_usage_in_mb(
transformed_X_train,
validator.feature_validator.categorical_columns,
validator.feature_validator.num_categories_per_col
) < get_approximate_mem_usage_in_mb(
X_train, categorical_columns, validator.feature_validator.num_categories_per_col)

transformed_X_test, _ = validator.transform(X_test.copy(), y_test.copy())
assert megabytes(transformed_X_test) < megabytes(X_test)
assert get_approximate_mem_usage_in_mb(
transformed_X_test,
validator.feature_validator.categorical_columns,
validator.feature_validator.num_categories_per_col
) < get_approximate_mem_usage_in_mb(
X_test, categorical_columns, validator.feature_validator.num_categories_per_col)

if hasattr(transformed_X_train, 'iloc'):
assert all(transformed_X_train.dtypes == transformed_X_test.dtypes)
assert all(transformed_X_train.dtypes == validator._precision)
Expand Down

0 comments on commit 09fdc0d

Please sign in to comment.