diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index 49d241daa..6b1631c0b 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -22,7 +22,7 @@ def _create_column_transformer( - preprocessors: Dict[str], + preprocessors: Dict, numerical_columns: List[str], categorical_columns: List[str], ) -> ColumnTransformer: @@ -329,14 +329,14 @@ def _get_columns_info( # Make sure each column is a valid type for i, column in enumerate(X.columns): column_dtype = self.dtypes[i] - if column_dtype.name in ['category', 'bool']: + if column_dtype in ['category', 'bool']: categorical_columns.append(column) feat_type.append('categorical') # Move away from np.issubdtype as it causes # TypeError: data type not understood in certain pandas types elif not is_numeric_dtype(column_dtype): # TODO verify how would this happen when we always convert the object dtypes to category - if column_dtype.name == 'object': + if column_dtype == 'object': raise ValueError( "Input Column {} has invalid type object. " "Cast it to a valid dtype before using it in AutoPyTorch. " @@ -368,7 +368,7 @@ def _get_columns_info( "Make sure your data is formatted in a correct way, " "before feeding it to AutoPyTorch.".format( column, - column_dtype.name, + column_dtype, ) ) else: