diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py
index c2e220875..14aa6ab83 100644
--- a/autoPyTorch/api/base_task.py
+++ b/autoPyTorch/api/base_task.py
@@ -397,6 +397,7 @@ def _clean_logger(self) -> None:
         self.logging_server.join(timeout=5)
         self.logging_server.terminate()
         del self.stop_logging_server
+        self._logger = None
 
     def _create_dask_client(self) -> None:
         """
@@ -491,6 +492,23 @@ def _load_models(self) -> bool:
 
         return True
 
+    def _cleanup(self) -> None:
+        """
+        Closes the different servers created during api search.
+        Returns:
+            None
+        """
+        if hasattr(self, '_logger') and self._logger is not None:
+            self._logger.info("Closing the dask infrastructure")
+            self._close_dask_client()
+            self._logger.info("Finished closing the dask infrastructure")
+
+            # Clean up the logger
+            self._logger.info("Starting to clean up the logger")
+            self._clean_logger()
+        else:
+            self._close_dask_client()
+
     def _load_best_individual_model(self) -> SingleBest:
         """
         In case of failure during ensemble building,
@@ -923,6 +941,8 @@ def _search(
             self._stopwatch.stop_task(traditional_task_name)
 
         # ============> Starting ensemble
+        self.precision = precision
+        self.opt_metric = optimize_metric
         elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
         time_left_for_ensembles = max(0, total_walltime_limit - elapsed_time)
         proc_ensemble = None
@@ -1024,18 +1044,12 @@ def _search(
             pd.DataFrame(self.ensemble_performance_history).to_json(
                 os.path.join(self._backend.internals_directory, 'ensemble_history.json'))
 
-        self._logger.info("Closing the dask infrastructure")
-        self._close_dask_client()
-        self._logger.info("Finished closing the dask infrastructure")
-
         if load_models:
             self._logger.info("Loading models...")
             self._load_models()
             self._logger.info("Finished loading models...")
 
-        # Clean up the logger
-        self._logger.info("Starting to clean up the logger")
-        self._clean_logger()
+        self._cleanup()
 
         return self
 
@@ -1230,7 +1244,7 @@ def fit_pipeline(self,
         dataset_requirements = get_dataset_requirements(
             info=self._get_required_dataset_properties(dataset))
         dataset_properties = dataset.get_dataset_properties(dataset_requirements)
-        self._backend.save_datamanager(dataset)
+        self._backend.replace_datamanager(dataset)
 
         if self._logger is None:
             self._logger = self._get_logger(dataset.dataset_name)
@@ -1506,7 +1520,7 @@ def predict(
 
         predictions = self.ensemble_.predict(all_predictions)
 
-        self._clean_logger()
+        self._cleanup()
 
         return predictions
 
@@ -1543,10 +1557,7 @@ def __getstate__(self) -> Dict[str, Any]:
         return self.__dict__
 
     def __del__(self) -> None:
-        # Clean up the logger
-        self._clean_logger()
-
-        self._close_dask_client()
+        self._cleanup()
 
         # When a multiprocessing work is done, the
         # objects are deleted. We don't want to delete run areas
diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py
index 16185817b..28d64a4b1 100644
--- a/autoPyTorch/data/tabular_feature_validator.py
+++ b/autoPyTorch/data/tabular_feature_validator.py
@@ -41,26 +41,26 @@ def get_tabular_preprocessors():
     preprocessors['numerical'] = list()
     preprocessors['categorical'] = list()
 
+    # preprocessors['categorical'].append(SimpleImputer(strategy='constant',
+    #                                                   # Train data is numpy
+    #                                                   # as of this point, where
+    #                                                   # Ordinal Encoding is using
+    #                                                   # for categorical. Only
+    #                                                   # Numbers are allowed
+    #                                                   # fill_value='!missing!',
+    #                                                   fill_value=-1,
+    #                                                   copy=False))
+
+    # preprocessors['categorical'].append(OrdinalEncoder(
+    #     handle_unknown='use_encoded_value',
+    #     unknown_value=-1))
+
     preprocessors['categorical'].append(OneHotEncoder(
         categories='auto',
         sparse=False,
         handle_unknown='ignore'))
-    preprocessors['categorical'].append(SimpleImputer(strategy='constant',
-                                                      # Train data is numpy
-                                                      # as of this point, where
-                                                      # Ordinal Encoding is using
-                                                      # for categorical. Only
-                                                      # Numbers are allowed
-                                                      # fill_value='!missing!',
-                                                      fill_value=-1,
-                                                      copy=False))
-
-    preprocessors['categorical'].append(OrdinalEncoder(
-        handle_unknown='use_encoded_value',
-        unknown_value=-1))
-
     preprocessors['numerical'].append(SimpleImputer(strategy='median',
-                                                    copy=False))
+                                                    copy=False))
     preprocessors['numerical'].append(StandardScaler(with_mean=True, with_std=True, copy=False))
 
     return preprocessors
diff --git a/autoPyTorch/pipeline/base_pipeline.py b/autoPyTorch/pipeline/base_pipeline.py
index 842f63271..80d59a68f 100644
--- a/autoPyTorch/pipeline/base_pipeline.py
+++ b/autoPyTorch/pipeline/base_pipeline.py
@@ -451,12 +451,13 @@ def _check_search_space_updates(self, include: Optional[Dict[str, Any]],
                         continue
                     raise ValueError("Unknown hyperparameter for component {}. "
                                      "Expected update hyperparameter "
-                                     "to be in {} got {}".format(node.__class__.__name__,
+                                     "to be in {} got {}. choice is {}".format(node.__class__.__name__,
                                                                  component.
                                                                  get_hyperparameter_search_space(
                                                                      dataset_properties=self.dataset_properties).
                                                                  get_hyperparameter_names(),
-                                                                 split_hyperparameter[1]))
+                                                                 split_hyperparameter[1],
+                                                                 component.__name__))
                 else:
                     if update.hyperparameter not in node.get_hyperparameter_search_space(
                             dataset_properties=self.dataset_properties):
diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py
index 5fcf5cfb5..c7ca61e09 100644
--- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py
+++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py
@@ -23,7 +23,6 @@ def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = N
         self.add_fit_requirements([
             FitRequirement('numerical_columns', (List,), user_defined=True, dataset_property=True),
             FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True)])
-        self.fit_time = None
 
     def get_column_transformer(self) -> ColumnTransformer:
         """
@@ -48,7 +47,6 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "TabularColumnTransformer":
         Returns:
             "TabularColumnTransformer": an instance of self
         """
-        start_time = time.time()
         self.check_requirements(X, y)
 
         numerical_pipeline = 'passthrough'
@@ -74,7 +72,6 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "TabularColumnTransformer":
 
         X_train = X['backend'].load_datamanager().train_tensors[0]
         self.preprocessor.fit(X_train)
-        self.fit_time = time.time() - start_time
         return self
diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py
index 069ca4679..10f509741 100644
--- a/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py
+++ b/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py
@@ -139,6 +139,14 @@ def get_hyperparameter_search_space(
                                                                                value_range=(True, False),
                                                                               default_value=True,
                                                                               ),
+        shake_alpha_beta_method: HyperparameterSearchSpace = HyperparameterSearchSpace(
+            hyperparameter="shake_alpha_beta_method",
+            value_range=('shake-shake',
+                         'shake-even',
+                         'even-even',
+                         'M3'),
+            default_value='shake-shake',
+        ),
         use_shake_drop: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="use_shake_drop",
                                                                               value_range=(True, False),
                                                                               default_value=True,
@@ -180,9 +188,8 @@ def get_hyperparameter_search_space(
 
         if skip_connection_flag:
 
-            shake_drop_prob_flag = False
-            if 'shake-drop' in multi_branch_choice.value_range:
-                shake_drop_prob_flag = True
+            shake_shake_flag = 'shake-shake' in multi_branch_choice.value_range
+            shake_drop_prob_flag = 'shake-drop' in multi_branch_choice.value_range
 
             mb_choice = get_hyperparameter(multi_branch_choice, CategoricalHyperparameter)
             cs.add_hyperparameter(mb_choice)
@@ -192,6 +199,10 @@
             shake_drop_prob = get_hyperparameter(max_shake_drop_probability, UniformFloatHyperparameter)
             cs.add_hyperparameter(shake_drop_prob)
             cs.add_condition(CS.EqualsCondition(shake_drop_prob, mb_choice, "shake-drop"))
+            if shake_shake_flag or shake_drop_prob_flag:
+                method = get_hyperparameter(shake_alpha_beta_method, CategoricalHyperparameter)
+                cs.add_hyperparameter(method)
+                cs.add_condition(CS.InCondition(method, mb_choice, ["shake-shake", "shake-drop"]))
 
         # It is the upper bound of the nr of groups,
         # since the configuration will actually be sampled.
@@ -327,11 +338,14 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
         if self.config["multi_branch_choice"] == 'shake-shake':
             x1 = self.layers(x)
             x2 = self.shake_shake_layers(x)
-            alpha, beta = shake_get_alpha_beta(self.training, x.is_cuda)
+            alpha, beta = shake_get_alpha_beta(is_training=self.training,
+                                               is_cuda=x.is_cuda,
+                                               method=self.config['shake_alpha_beta_method'])
             x = shake_shake(x1, x2, alpha, beta)
         elif self.config["multi_branch_choice"] == 'shake-drop':
             x = self.layers(x)
-            alpha, beta = shake_get_alpha_beta(self.training, x.is_cuda)
+            alpha, beta = shake_get_alpha_beta(self.training, x.is_cuda,
+                                               method=self.config['shake_alpha_beta_method'])
             bl = shake_drop_get_bl(
                 self.block_index,
                 1 - self.config["max_shake_drop_probability"],
diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py
index e0867cdd3..12c6d4e74 100644
--- a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py
+++ b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py
@@ -145,6 +145,14 @@ def get_hyperparameter_search_space(  # type: ignore[override]
                                                             'stairs'),
                                               default_value='funnel',
                                               ),
+        shake_alpha_beta_method: HyperparameterSearchSpace = HyperparameterSearchSpace(
+            hyperparameter="shake_alpha_beta_method",
+            value_range=('shake-shake',
+                         'shake-even',
+                         'even-even',
+                         'M3'),
+            default_value='shake-shake',
+        ),
         max_shake_drop_probability: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="max_shake_drop_probability",
             value_range=(0, 1),
             default_value=0.5),
@@ -188,9 +196,8 @@ def get_hyperparameter_search_space(  # type: ignore[override]
 
         if skip_connection_flag:
 
-            shake_drop_prob_flag = False
-            if 'shake-drop' in multi_branch_choice.value_range:
-                shake_drop_prob_flag = True
+            shake_shake_flag = 'shake-shake' in multi_branch_choice.value_range
+            shake_drop_prob_flag = 'shake-drop' in multi_branch_choice.value_range
 
             mb_choice = get_hyperparameter(multi_branch_choice, CategoricalHyperparameter)
             cs.add_hyperparameter(mb_choice)
 
@@ -200,5 +207,9 @@
             shake_drop_prob = get_hyperparameter(max_shake_drop_probability, UniformFloatHyperparameter)
             cs.add_hyperparameter(shake_drop_prob)
             cs.add_condition(CS.EqualsCondition(shake_drop_prob, mb_choice, "shake-drop"))
+            if shake_shake_flag or shake_drop_prob_flag:
+                method = get_hyperparameter(shake_alpha_beta_method, CategoricalHyperparameter)
+                cs.add_hyperparameter(method)
+                cs.add_condition(CS.InCondition(method, mb_choice, ["shake-shake", "shake-drop"]))
 
         return cs
diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/utils.py b/autoPyTorch/pipeline/components/setup/network_backbone/utils.py
index ef19beac8..9a1f9dd4e 100644
--- a/autoPyTorch/pipeline/components/setup/network_backbone/utils.py
+++ b/autoPyTorch/pipeline/components/setup/network_backbone/utils.py
@@ -92,15 +92,35 @@ def backward(ctx: typing.Any,
 
 shake_drop = ShakeDropFunction.apply
 
-def shake_get_alpha_beta(is_training: bool, is_cuda: bool
-                         ) -> typing.Tuple[torch.tensor, torch.tensor]:
+def shake_get_alpha_beta(
+    is_training: bool,
+    is_cuda: bool,
+    method: str
+) -> typing.Tuple[torch.tensor, torch.tensor]:
+    """
+    The methods used in this function have been introduced in 'Shake-Shake regularization'
+    (https://arxiv.org/abs/1705.07485). The method names have been taken from the paper as well.
+    """
     if not is_training:
         result = (torch.FloatTensor([0.5]), torch.FloatTensor([0.5]))
         return result if not is_cuda else (result[0].cuda(), result[1].cuda())
 
-    # TODO implement other update methods
-    alpha = torch.rand(1)
-    beta = torch.rand(1)
+    if method == 'even-even':
+        alpha = torch.FloatTensor([0.5])
+    else:
+        alpha = torch.rand(1)
+
+    if method == 'shake-shake':
+        beta = torch.rand(1)
+    elif method in ['shake-even', 'even-even']:
+        beta = torch.FloatTensor([0.5])
+    elif method == 'M3':
+        # M3 draws beta uniformly between 0.5 and alpha
+        beta = torch.FloatTensor(
+            [torch.rand(1) * (0.5 - alpha) + alpha if alpha < 0.5 else torch.rand(1) * (alpha - 0.5) + 0.5]
+        )
+    else:
+        raise ValueError("Unknown method for ShakeShakeRegularisation in NetworkBackbone")
 
     if is_cuda:
         alpha = alpha.cuda()
diff --git a/autoPyTorch/pipeline/components/setup/optimizer/AdamWOptimizer.py b/autoPyTorch/pipeline/components/setup/optimizer/AdamWOptimizer.py
index 4d11c3026..a415ff1c6 100644
--- a/autoPyTorch/pipeline/components/setup/optimizer/AdamWOptimizer.py
+++ b/autoPyTorch/pipeline/components/setup/optimizer/AdamWOptimizer.py
@@ -95,9 +95,9 @@ def get_hyperparameter_search_space(
                                                                   default_value=True,
                                                                   ),
         weight_decay: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="weight_decay",
-                                                                            value_range=(1E-7, 0.1),
+                                                                            value_range=(1E-5, 0.1),
                                                                             default_value=1E-4,
-                                                                            log=True),
+                                                                            log=False),
     ) -> ConfigurationSpace:
         cs = ConfigurationSpace()
diff --git a/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py b/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py
index 5b8e445ac..7302ac6f5 100644
--- a/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py
+++ b/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py
@@ -115,7 +115,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> torch.utils.data.DataLoader:
             shuffle=True,
             num_workers=X.get('num_workers', 0),
             pin_memory=X.get('pin_memory', True),
-            drop_last=X.get('drop_last', True),
+            drop_last=X.get('drop_last', False),
             collate_fn=custom_collate_fn,
         )
 
diff --git a/autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py b/autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py
index c5a536dd0..36d586919 100644
--- a/autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py
+++ b/autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py
@@ -189,12 +189,17 @@ def get_hyperparameter_search_space(
             default_value=3),
         epsilon: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="epsilon",
-            value_range=(0.05, 0.2),
-            default_value=0.2),
+            value_range=(0.001, 0.15),
+            default_value=0.007,
+            log=True),
     ) -> ConfigurationSpace:
         cs = ConfigurationSpace()
+        epsilon = HyperparameterSearchSpace(hyperparameter="epsilon",
+                                            value_range=(0.007, 0.007),
+                                            default_value=0.007)
         add_hyperparameter(cs, epsilon, UniformFloatHyperparameter)
+
         add_hyperparameter(cs, use_stochastic_weight_averaging, CategoricalHyperparameter)
 
         snapshot_ensemble_flag = False
         if any(use_snapshot_ensemble.value_range):
diff --git a/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py b/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py
index 20d02c793..f1b606046 100644
--- a/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py
+++ b/autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py
@@ -36,7 +36,7 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             return X, {'y_a': y, 'y_b': y[index], 'lam': 1}
 
         size = X.shape[1]
-        indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int32(size * lam)),
+        indices = torch.tensor(self.random_state.choice(range(size), max(1, np.int32(size * lam)),
                                                         replace=False))
 
         X[:, indices] = X[index, :][:, indices]
diff --git a/autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py b/autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py
index c09603523..d7bd23f4e 100644
--- a/autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py
+++ b/autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py
@@ -37,7 +37,7 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
 
         size = X.shape[1]
-        indices = self.random_state.choice(range(1, size), max(1, np.int32(size * self.patch_ratio)),
+        indices = self.random_state.choice(range(size), max(1, np.int32(size * self.patch_ratio)),
                                            replace=False)
 
         """if not isinstance(self.numerical_columns, typing.Iterable):
diff --git a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
index 188504da3..6040f32e9 100644
--- a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
+++ b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
@@ -371,16 +371,12 @@ def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
         outputs_data = list()
         targets_data = list()
 
-        batch_load_start_time = time.time()
         for step, (data, targets) in enumerate(train_loader):
-            self.data_loading_times.append(time.time() - batch_load_start_time)
-            batch_train_start = time.time()
             if self.budget_tracker.is_max_time_reached():
                 break
 
             loss, outputs = self.train_step(data, targets)
-            self.batch_fit_times.append(time.time() - batch_train_start)
 
             # save for metric evaluation
             outputs_data.append(outputs.detach().cpu())
             targets_data.append(targets.detach().cpu())
@@ -395,7 +391,6 @@ def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
                     loss,
                     epoch * len(train_loader) + step,
                 )
-            batch_load_start_time = time.time()
 
         if self.scheduler:
             if 'ReduceLROnPlateau' in self.scheduler.__class__.__name__:
diff --git a/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py b/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py
index 2dcb8fe16..a344e92ce 100755
--- a/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py
+++ b/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py
@@ -77,7 +77,6 @@ def __init__(self,
                                                    (torch.utils.data.DataLoader,),
                                                    user_defined=False, dataset_property=False)]
         self.checkpoint_dir = None  # type: Optional[str]
-        self.fit_time = None
 
     def get_fit_requirements(self) -> Optional[List[FitRequirement]]:
         return self._fit_requirements
@@ -264,7 +263,6 @@ def fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> autoPyTorchCom
         Returns:
             A instance of self
         """
-        start_time = time.time()
         # Make sure that the prerequisites are there
         self.check_requirements(X, y)
 
@@ -287,7 +285,6 @@ def fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> autoPyTorchCom
         self.choice: autoPyTorchComponent = cast(autoPyTorchComponent, self.choice)
         if self.choice.use_snapshot_ensemble:
             X['network_snapshots'].extend(self.choice.model_snapshots)
-        self.fit_time = time.time() - start_time
         return self.choice
 
     def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoice':
@@ -410,7 +407,6 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
             # change model
             update_model_state_dict_from_swa(X['network'], self.choice.swa_model.state_dict())
             if self.choice.use_snapshot_ensemble:
-                swa_utils.update_bn(X['train_data_loader'], model.double())
                 # we update only the last network which pertains to the stochastic weight averaging model
                 swa_utils.update_bn(X['train_data_loader'], self.choice.model_snapshots[-1].double())
 
diff --git a/autoPyTorch/pipeline/components/training/trainer/cutout_utils.py b/autoPyTorch/pipeline/components/training/trainer/cutout_utils.py
index c58546a4c..c7feb2214 100644
--- a/autoPyTorch/pipeline/components/training/trainer/cutout_utils.py
+++ b/autoPyTorch/pipeline/components/training/trainer/cutout_utils.py
@@ -53,8 +53,6 @@ def __init__(self, patch_ratio: float,
         self.lookahead_config = lookahead_config
         self.patch_ratio = patch_ratio
         self.cutout_prob = cutout_prob
-        self.batch_fit_times = []
-        self.data_loading_times = []
 
     def criterion_preparation(self, y_a: np.ndarray, y_b: np.ndarray = None, lam: float = 1.0
                               ) -> Callable:
diff --git a/autoPyTorch/pipeline/components/training/trainer/mixup_utils.py b/autoPyTorch/pipeline/components/training/trainer/mixup_utils.py
index b1cf37972..a2325b91c 100644
--- a/autoPyTorch/pipeline/components/training/trainer/mixup_utils.py
+++ b/autoPyTorch/pipeline/components/training/trainer/mixup_utils.py
@@ -51,8 +51,6 @@ def __init__(self, alpha: float,
                                 f'{Lookahead.__name__}:la_alpha': 0.6}
         self.lookahead_config = lookahead_config
         self.alpha = alpha
-        self.batch_fit_times = []
-        self.data_loading_times = []
 
     def criterion_preparation(self, y_a: np.ndarray, y_b: np.ndarray = None, lam: float = 1.0
                               ) -> Callable:
diff --git a/autoPyTorch/utils/backend.py b/autoPyTorch/utils/backend.py
index 713c7d572..7a7399a9f 100644
--- a/autoPyTorch/utils/backend.py
+++ b/autoPyTorch/utils/backend.py
@@ -328,6 +328,11 @@ def load_datamanager(self) -> BaseDataset:
         with open(filepath, 'rb') as fh:
             return pickle.load(fh)
 
+    def replace_datamanager(self, datamanager: BaseDataset):
+        warnings.warn("Original dataset will be overwritten with the provided dataset")
+        os.remove(self._get_datamanager_pickle_filename())
+        self.save_datamanager(datamanager=datamanager)
+
     def get_runs_directory(self) -> str:
         return os.path.join(self.internals_directory, 'runs')
 
diff --git a/examples/tabular/40_advanced/example_custom_configuration_space.py b/examples/tabular/40_advanced/example_custom_configuration_space.py
index 6a3764b94..b95ceeaa5 100644
--- a/examples/tabular/40_advanced/example_custom_configuration_space.py
+++ b/examples/tabular/40_advanced/example_custom_configuration_space.py
@@ -54,6 +54,15 @@ def get_search_space_updates():
                    hyperparameter='ResNetBackbone:dropout',
                    value_range=[0, 0.5],
                    default_value=0.2)
+    updates.append(node_name='network_backbone',
+                   hyperparameter='ResNetBackbone:multi_branch_choice',
+                   value_range=['shake-shake'],
+                   default_value='shake-shake')
+    updates.append(node_name='network_backbone',
+                   hyperparameter='ResNetBackbone:shake_alpha_beta_method',
+                   value_range=['M3'],
+                   default_value='M3'
+                   )
     return updates
 
 
@@ -74,7 +83,7 @@ def get_search_space_updates():
     # ==================================================
     api = TabularClassificationTask(
         search_space_updates=get_search_space_updates(),
-        include_components={'network_backbone': ['MLPBackbone', 'ResNetBackbone'],
+        include_components={'network_backbone': ['ResNetBackbone'],
                             'encoder': ['OneHotEncoder']}
     )
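
A minimal usage sketch, not part of the patch itself, that exercises the new shake_alpha_beta_method option introduced above. It assumes an autoPyTorch installation that already contains this patch; shake_get_alpha_beta is the function extended in autoPyTorch/pipeline/components/setup/network_backbone/utils.py.

    # Sketch only: sample the forward/backward shake coefficients for each supported method,
    # assuming this patch has been applied to the installed autoPyTorch.
    import torch

    from autoPyTorch.pipeline.components.setup.network_backbone.utils import shake_get_alpha_beta

    for method in ('shake-shake', 'shake-even', 'even-even', 'M3'):
        alpha, beta = shake_get_alpha_beta(is_training=True,
                                           is_cuda=torch.cuda.is_available(),
                                           method=method)
        print(f"{method:>11}: alpha={float(alpha):.3f}, beta={float(beta):.3f}")

    # At evaluation time both coefficients are fixed to 0.5, independent of the method.
    alpha, beta = shake_get_alpha_beta(is_training=False, is_cuda=False, method='shake-shake')
    assert float(alpha) == 0.5 and float(beta) == 0.5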