Fixing issues with imbalanced datasets (#197)
* Adding missing method from base_feature_validator

* First try at a fix, removing redundant code

* Fix bug

* Fixing a unit test typo and a bug where the data type was not checked because X was still a numpy array at the time of the check

* Fixing flake8 failures

* Bug fix and implementation update for imbalanced datasets, plus unit tests that check the implementation

* flake8 fix

* Bug fix

* Making the DataFrame conversion in the unit tests consistent with what happens in the validator, so the types do not change

* flake8 fix

* Addressing Ravin's comments
ArlindKadra authored and ravinkohli committed Jan 27, 2022
1 parent e232f61 commit e9eb006
Showing 4 changed files with 118 additions and 36 deletions.
14 changes: 14 additions & 0 deletions autoPyTorch/data/base_feature_validator.py
@@ -122,6 +122,20 @@ def _fit(
"""
raise NotImplementedError()

def _check_data(
self,
X: SUPPORTED_FEAT_TYPES,
) -> None:
"""
Feature dimensionality and data type checks
Arguments:
X (SUPPORTED_FEAT_TYPES):
A set of features that are going to be validated (type and dimensionality
checks) and an encoder fitted in case the data needs encoding
"""
raise NotImplementedError()

def transform(
self,
X: SUPPORTED_FEAT_TYPES,
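For context, a minimal standalone sketch of the kind of check _check_data is expected to perform is shown below; the SUPPORTED_FEAT_TYPES alias, the function name and the error messages are illustrative assumptions, not the project's actual implementation (which lives in tabular_feature_validator.py).

from typing import Union

import numpy as np
import pandas as pd
from scipy import sparse

# Assumed stand-in for the project's SUPPORTED_FEAT_TYPES alias.
SUPPORTED_FEAT_TYPES = Union[np.ndarray, pd.DataFrame, sparse.spmatrix]


def check_data(X: SUPPORTED_FEAT_TYPES) -> None:
    """Illustrative dimensionality and data type checks in the spirit of _check_data."""
    if not (isinstance(X, (np.ndarray, pd.DataFrame)) or sparse.issparse(X)):
        raise ValueError(f"Unsupported data type for X: {type(X)}")
    if hasattr(X, "ndim") and X.ndim > 2:
        raise ValueError("Only 2-dimensional feature sets are supported")


check_data(pd.DataFrame({"a": [1, 2, 3]}))  # passes silently
# check_data({"a": [1, 2, 3]})              # would raise ValueError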
68 changes: 37 additions & 31 deletions autoPyTorch/data/tabular_feature_validator.py
@@ -147,9 +147,13 @@ def _fit(
# with nan values.
# Columns that are completely made of NaN values are provided to the pipeline
# so that later stages decide how to handle them

# Clear whatever null column markers we had previously
self.null_columns.clear()
if np.any(pd.isnull(X)):
for column in X.columns:
if X[column].isna().all():
self.null_columns.add(column)
X[column] = pd.to_numeric(X[column])
# Also note this change in self.dtypes
if len(self.dtypes) != 0:
@@ -158,9 +162,8 @@
if not X.select_dtypes(include='object').empty:
X = self.infer_objects(X)

self.transformed_columns, self.feat_type = self._get_columns_to_encode(X)

assert self.feat_type is not None
self._check_data(X)
self.enc_columns, self.feat_type = self._get_columns_to_encode(X)

if len(self.transformed_columns) > 0:

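The bookkeeping above leans on pd.to_numeric turning an all-NaN object column into a float column, so later pipeline stages receive a numeric dtype instead of object; a small self-contained illustration of that behavior (column names and values are made up):

import numpy as np
import pandas as pd

X = pd.DataFrame({
    "A": pd.Series([np.nan, np.nan, np.nan], dtype=object),  # completely null, stored as object
    "B": [1, 2, 3],
})

null_columns = set()
for column in X.columns:
    if X[column].isna().all():
        null_columns.add(column)
        # Make the all-null column numeric so downstream stages can decide how to handle it.
        X[column] = pd.to_numeric(X[column])

print(null_columns)  # {'A'}
print(X.dtypes)      # A -> float64, B -> int64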
@@ -230,29 +233,37 @@ def transform(
X = self.numpy_array_to_pandas(X)

if hasattr(X, "iloc") and not scipy.sparse.issparse(X):
if np.any(pd.isnull(X)):
for column in X.columns:
if X[column].isna().all():
X[column] = pd.to_numeric(X[column])
X = typing.cast(pd.DataFrame, X)
# If we had null columns in our fit call and we made them numeric, then:
# - If the columns are null even in transform, apply the same procedure.
# - Otherwise, substitute the values with np.NaN and then make the columns numeric.
# If the column is null here, but it was not in fit, it does not matter.
for column in self.null_columns:
# The column is not null, make it null since it was null in fit.
if not X[column].isna().all():
X[column] = np.NaN
X[column] = pd.to_numeric(X[column])

# for the test set, if we have columns with only null values
# they will probably have a numeric type. If these columns were not
# with only null values in the train set, they should be converted
# to the type that they had during fitting.
for column in X.columns:
if X[column].isna().all():
X[column] = X[column].astype(self.dtypes[list(X.columns).index(column)])

# Also remove the object dtype for new data
if not X.select_dtypes(include='object').empty:
X = self.infer_objects(X)

# Check the data here so we catch problems on new test data
self._check_data(X)
# We also need to fillna on the transformation
# in case test data is provided
X = self.impute_nan_in_categories(X)

# Pandas related transformations
if hasattr(X, "iloc") and self.column_transformer is not None:
if np.any(pd.isnull(X)):
# After above check it means that if there is a NaN
# the whole column must be NaN
# Make sure it is numerical and let the pipeline handle it
for column in X.columns:
if X[column].isna().all():
X[column] = pd.to_numeric(X[column])

X = self.column_transformer.transform(X)
if self.encoder is not None:
X = self.encoder.transform(X)

# Sparse related transformations
# Not all sparse format support index sorting
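The two test-time branches above can be reproduced in isolation: columns that were completely null at fit time are forced back to NaN and made numeric, while columns that are all-null only at test time are cast back to their fit-time dtype. A simplified sketch under those assumptions (the real code keeps the dtypes in a list indexed by column position; the names and values here are illustrative):

import numpy as np
import pandas as pd

# State remembered from fit (illustrative values): "C" was completely null
# at fit time, and a dtype was recorded for every column.
fit_null_columns = {"C"}
fit_dtypes = {"A": np.dtype("object"), "B": np.dtype("int64"), "C": np.dtype("float64")}

X_test = pd.DataFrame({
    "A": [np.nan, np.nan, np.nan],   # held strings at fit time, all-null only now
    "B": [6, 5, 7],
    "C": [1.0, 2.0, 3.0],            # was all-null at fit time, has values now
})

# Columns that were null during fit are nulled again and made numeric.
for column in fit_null_columns:
    if not X_test[column].isna().all():
        X_test[column] = np.nan
    X_test[column] = pd.to_numeric(X_test[column])

# Columns that are all-null only at test time get their fit-time dtype back.
for column in X_test.columns:
    if X_test[column].isna().all():
        X_test[column] = X_test[column].astype(fit_dtypes[column])

print(X_test.dtypes)  # A -> object, B -> int64, C -> float64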
@@ -478,7 +489,7 @@ def numpy_array_to_pandas(
Returns:
pd.DataFrame
"""
return pd.DataFrame(X).infer_objects().convert_dtypes()
return pd.DataFrame(X).convert_dtypes()
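For reference, convert_dtypes() already performs object inference on its own (its infer_objects argument defaults to True) and maps columns onto pandas' nullable extension dtypes, so missing entries survive as pd.NA rather than forcing a float cast; a small illustration with made-up values (the exact resulting dtypes can vary slightly between pandas versions):

import numpy as np
import pandas as pd

X = np.array([[1, "a"], [2, "b"], [3, None]], dtype=object)

df = pd.DataFrame(X).convert_dtypes()
print(df.dtypes)
# 0     Int64    (nullable integer)
# 1    string    (the None entry becomes <NA>)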

def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
"""
@@ -496,18 +507,13 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
if hasattr(self, 'object_dtype_mapping'):
# Mypy does not process the has attr. This dict is defined below
for key, dtype in self.object_dtype_mapping.items(): # type: ignore[has-type]
if 'int' in dtype.name:
# In the case train data was interpreted as int
# and test data was interpreted as float, because of 0.0
# for example, honor training data
X[key] = X[key].applymap(np.int64)
else:
try:
X[key] = X[key].astype(dtype.name)
except Exception as e:
# Try inference if possible
self.logger.warning(f"Tried to cast column {key} to {dtype} caused {e}")
pass
# honor the training data types
try:
X[key] = X[key].astype(dtype.name)
except Exception as e:
# Try inference if possible
self.logger.warning(f"Tried to cast column {key} to {dtype} caused {e}")
pass
else:
X = X.infer_objects()
for column in X.columns:
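The effect of honoring the training dtypes can be seen in isolation: if a column was typed int64 when fit() saw it but arrives as float at test time (for example because it contains 0.0), it is cast back, and a failed cast only triggers a warning. A standalone sketch with print() standing in for the validator's logger and an assumed column name:

import numpy as np
import pandas as pd

# Dtypes recorded for object columns when fit() first saw the data (assumed example).
object_dtype_mapping = {"age": np.dtype("int64")}

X_test = pd.DataFrame({"age": [0.0, 25.0, 31.0]})   # read back as float64

for key, dtype in object_dtype_mapping.items():
    try:
        # Honor the training data types.
        X_test[key] = X_test[key].astype(dtype.name)
    except Exception as e:
        # The validator only warns and falls back to the inferred dtype.
        print(f"Tried to cast column {key} to {dtype} caused {e}")

print(X_test.dtypes)  # age -> int64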
71 changes: 67 additions & 4 deletions test/test_data/test_feature_validator.py
@@ -286,9 +286,9 @@ def test_featurevalidator_fitontypeA_transformtypeB(input_data_featuretest):
if isinstance(input_data_featuretest, pd.DataFrame):
pytest.skip("Column order change in pandas is not supported")
elif isinstance(input_data_featuretest, np.ndarray):
complementary_type = pd.DataFrame(input_data_featuretest)
complementary_type = validator.numpy_array_to_pandas(input_data_featuretest)
elif isinstance(input_data_featuretest, list):
complementary_type = pd.DataFrame(input_data_featuretest)
complementary_type, _ = validator.list_to_dataframe(input_data_featuretest)
elif sparse.issparse(input_data_featuretest):
complementary_type = sparse.csr_matrix(input_data_featuretest.todense())
else:
@@ -478,8 +478,11 @@ def test_unknown_encode_value():
)
@pytest.mark.parametrize('train_data_type', ('numpy', 'pandas', 'list'))
@pytest.mark.parametrize('test_data_type', ('numpy', 'pandas', 'list'))
def test_featurevalidator_new_data_after_fit(openml_id,
train_data_type, test_data_type):
def test_feature_validator_new_data_after_fit(
openml_id,
train_data_type,
test_data_type,
):

# List is currently not supported as infer_objects
# cast list objects to type objects
@@ -557,3 +560,63 @@ def test_comparator():
key=functools.cmp_to_key(validator._comparator)
)
assert ans == feat_type


def test_feature_validator_imbalanced_data():

# Null columns in the train split but not necessarily in the test split
train_features = {
'A': [np.NaN, np.NaN, np.NaN],
'B': [1, 2, 3],
'C': [np.NaN, np.NaN, np.NaN],
'D': [np.NaN, np.NaN, np.NaN],
}
test_features = {
'A': [3, 4, 5],
'B': [6, 5, 7],
'C': [np.NaN, np.NaN, np.NaN],
'D': ['Blue', np.NaN, np.NaN],
}

X_train = pd.DataFrame.from_dict(train_features)
X_test = pd.DataFrame.from_dict(test_features)
validator = TabularFeatureValidator()
validator.fit(X_train)

train_feature_types = copy.deepcopy(validator.feat_type)
assert train_feature_types == ['numerical', 'numerical', 'numerical', 'numerical']
# validator will throw an error if the column types are not the same
transformed_X_test = validator.transform(X_test)
transformed_X_test = pd.DataFrame(transformed_X_test)
null_columns = []
for column in transformed_X_test.columns:
if transformed_X_test[column].isna().all():
null_columns.append(column)
assert null_columns == [0, 2, 3]

# Columns with not all null values in the train split and
# completely null on the test split.
train_features = {
'A': [np.NaN, np.NaN, 4],
'B': [1, 2, 3],
'C': ['Blue', np.NaN, np.NaN],
}
test_features = {
'A': [np.NaN, np.NaN, np.NaN],
'B': [6, 5, 7],
'C': [np.NaN, np.NaN, np.NaN],
}

X_train = pd.DataFrame.from_dict(train_features)
X_test = pd.DataFrame.from_dict(test_features)
validator = TabularFeatureValidator()
validator.fit(X_train)
train_feature_types = copy.deepcopy(validator.feat_type)
assert train_feature_types == ['categorical', 'numerical', 'numerical']

transformed_X_test = validator.transform(X_test)
transformed_X_test = pd.DataFrame(transformed_X_test)
null_columns = []
for column in transformed_X_test.columns:
if transformed_X_test[column].isna().all():
null_columns.append(column)

assert null_columns == [1]
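Beyond the assertions above, the bookkeeping added in this commit can also be inspected directly after fitting; a short usage sketch along the same lines (the expected values follow from the fit() logic in this diff):

import numpy as np
import pandas as pd

from autoPyTorch.data.tabular_feature_validator import TabularFeatureValidator

X_train = pd.DataFrame({
    "A": [np.nan, np.nan, np.nan],   # completely null at fit time
    "B": [1, 2, 3],
})

validator = TabularFeatureValidator()
validator.fit(X_train)

# Columns that were entirely null during fit are remembered and kept numeric
# both at fit and at transform time.
print(validator.null_columns)   # expected: {'A'}
print(validator.feat_type)      # expected: ['numerical', 'numerical']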
1 change: 0 additions & 1 deletion test/test_data/test_validation.py
@@ -31,7 +31,6 @@ def test_data_validation_for_classification(openmlid, as_frame):
x, y, test_size=0.33, random_state=0)

validator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)

X_train_t, y_train_t = validator.transform(X_train, y_train)
assert np.shape(X_train) == np.shape(X_train_t)
