From 8bf6280d590b63443e8fc77be01eb4dbbd3d5509 Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Mon, 28 Feb 2022 15:22:44 +0100
Subject: [PATCH 1/4] Enable learned embeddings, fix bug with non-cyclic
 schedulers

---
 .../setup/network_embedding/__init__.py | 101 ++++++++----------
 .../training/trainer/base_trainer.py    |  71 ++++++------
 2 files changed, 86 insertions(+), 86 deletions(-)

diff --git a/autoPyTorch/pipeline/components/setup/network_embedding/__init__.py b/autoPyTorch/pipeline/components/setup/network_embedding/__init__.py
index d59597040..381e0735d 100644
--- a/autoPyTorch/pipeline/components/setup/network_embedding/__init__.py
+++ b/autoPyTorch/pipeline/components/setup/network_embedding/__init__.py
@@ -148,71 +148,62 @@ def get_hyperparameter_search_space(
         if default is None:
             defaults = [
                 'NoEmbedding',
-                # 'LearnedEntityEmbedding',
+                'LearnedEntityEmbedding',
             ]
             for default_ in defaults:
                 if default_ in available_embedding:
                     default = default_
                     break

-        # Restrict embedding to NoEmbedding until preprocessing is fixed
-        embedding = CSH.CategoricalHyperparameter('__choice__',
-                                                  ['NoEmbedding'],
-                                                  default_value=default)
+        categorical_columns = dataset_properties['categorical_columns'] \
+            if isinstance(dataset_properties['categorical_columns'], List) else []
+
+        updates = self._get_search_space_updates()
+        if '__choice__' in updates.keys():
+            choice_hyperparameter = updates['__choice__']
+            if not set(choice_hyperparameter.value_range).issubset(available_embedding):
+                raise ValueError("Expected given update for {} to have "
+                                 "choices in {} got {}".format(self.__class__.__name__,
+                                                               available_embedding,
+                                                               choice_hyperparameter.value_range))
+            if len(categorical_columns) == 0:
+                assert len(choice_hyperparameter.value_range) == 1
+                if 'NoEmbedding' not in choice_hyperparameter.value_range:
+                    raise ValueError("Provided {} in choices, however, the dataset "
+                                     "is incompatible with it".format(choice_hyperparameter.value_range))
+            embedding = CSH.CategoricalHyperparameter('__choice__',
+                                                      choice_hyperparameter.value_range,
+                                                      default_value=choice_hyperparameter.default_value)
+        else:
+
+            if len(categorical_columns) == 0:
+                default = 'NoEmbedding'
+                if include is not None and default not in include:
+                    raise ValueError("Provided {} in include, however, the dataset "
+                                     "is incompatible with it".format(include))
+                embedding = CSH.CategoricalHyperparameter('__choice__',
+                                                          ['NoEmbedding'],
+                                                          default_value=default)
+            else:
+                embedding = CSH.CategoricalHyperparameter('__choice__',
+                                                          list(available_embedding.keys()),
+                                                          default_value=default)
+
+        cs.add_hyperparameter(embedding)
+        for name in embedding.choices:
+            updates = self._get_search_space_updates(prefix=name)
+            config_space = available_embedding[name].get_hyperparameter_search_space(dataset_properties,  # type: ignore
+                                                                                     **updates)
+            parent_hyperparameter = {'parent': embedding, 'value': name}
+            cs.add_configuration_space(
+                name,
+                config_space,
+                parent_hyperparameter=parent_hyperparameter
+            )
+
         self.configuration_space_ = cs
         self.dataset_properties_ = dataset_properties
         return cs
-        # categorical_columns = dataset_properties['categorical_columns'] \
-        #     if isinstance(dataset_properties['categorical_columns'], List) else []
-
-        # updates = self._get_search_space_updates()
-        # if '__choice__' in updates.keys():
-        #     choice_hyperparameter = updates['__choice__']
-        #     if not set(choice_hyperparameter.value_range).issubset(available_embedding):
-        #         raise ValueError("Expected given update for {} to have "
-        #                          "choices in {} got {}".format(self.__class__.__name__,
-        #                                                        available_embedding,
-        #                                                        choice_hyperparameter.value_range))
-        #     if len(categorical_columns) == 0:
-        #         assert len(choice_hyperparameter.value_range) == 1
-        #         if 'NoEmbedding' not in choice_hyperparameter.value_range:
-        #             raise ValueError("Provided {} in choices, however, the dataset "
-        #                              "is incompatible with it".format(choice_hyperparameter.value_range))
-        #         embedding = CSH.CategoricalHyperparameter('__choice__',
-        #                                                   choice_hyperparameter.value_range,
-        #                                                   default_value=choice_hyperparameter.default_value)
-        # else:
-
-        #     if len(categorical_columns) == 0:
-        #         default = 'NoEmbedding'
-        #         if include is not None and default not in include:
-        #             raise ValueError("Provided {} in include, however, the dataset "
-        #                              "is incompatible with it".format(include))
-        #         embedding = CSH.CategoricalHyperparameter('__choice__',
-        #                                                   ['NoEmbedding'],
-        #                                                   default_value=default)
-        #     else:
-        #         embedding = CSH.CategoricalHyperparameter('__choice__',
-        #                                                   list(available_embedding.keys()),
-        #                                                   default_value=default)
-
-        # cs.add_hyperparameter(embedding)
-        # for name in embedding.choices:
-        #     updates = self._get_search_space_updates(prefix=name)
-        #     config_space = available_embedding[name].get_hyperparameter_search_space(
-        #         dataset_properties,  # type: ignore
-        #         **updates)
-        #     parent_hyperparameter = {'parent': embedding, 'value': name}
-        #     cs.add_configuration_space(
-        #         name,
-        #         config_space,
-        #         parent_hyperparameter=parent_hyperparameter
-        #     )
-
-        # self.configuration_space_ = cs
-        # self.dataset_properties_ = dataset_properties
-        # return cs

     def transform(self, X: np.ndarray) -> np.ndarray:
         assert self.choice is not None, "Cannot call transform before the object is initialized"
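Annotation: the rewritten block above registers each embedding's sub-space under the `__choice__` hyperparameter, so a component's hyperparameters are only sampled when that component is selected. A minimal, self-contained sketch of the same ConfigSpace pattern; the sub-space hyperparameter name below is illustrative, not the component's exact one:

    from ConfigSpace.configuration_space import ConfigurationSpace
    from ConfigSpace.hyperparameters import (
        CategoricalHyperparameter,
        UniformIntegerHyperparameter,
    )

    cs = ConfigurationSpace()
    choice = CategoricalHyperparameter(
        '__choice__', ['NoEmbedding', 'LearnedEntityEmbedding'],
        default_value='NoEmbedding')
    cs.add_hyperparameter(choice)

    # Sub-space for one component; 'embedding_dim' is an illustrative name.
    sub_cs = ConfigurationSpace()
    sub_cs.add_hyperparameter(UniformIntegerHyperparameter('embedding_dim', 2, 8))

    # Hyperparameters in sub_cs become active only when
    # __choice__ == 'LearnedEntityEmbedding'.
    cs.add_configuration_space(
        'LearnedEntityEmbedding', sub_cs,
        parent_hyperparameter={'parent': choice, 'value': 'LearnedEntityEmbedding'})
    print(cs.sample_configuration())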
diff --git a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
index 517ae08bb..cf634953a 100644
--- a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
+++ b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
@@ -334,6 +334,35 @@ def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
         """
         pass

+    def _swa_update(self) -> None:
+        """
+        perform swa model update
+        """
+        assert self.swa_model is not None, "SWA model can't be none when" \
+            " stochastic weight averaging is enabled"
+        self.swa_model.update_parameters(self.model)
+        self.swa_updated = True
+
+    def _se_update(self, epoch: int) -> None:
+        """
+        Add latest model or swa_model to model snapshot ensemble
+        Args:
+            epoch (int):
+                current epoch
+        """
+        assert self.model_snapshots is not None, "model snapshots container can't be " \
+            "none when snapshot ensembling is enabled"
+        is_last_epoch = (epoch == self.budget_tracker.max_epochs)
+        if is_last_epoch and self.use_stochastic_weight_averaging:
+            model_copy = deepcopy(self.swa_model)
+        else:
+            model_copy = deepcopy(self.model)
+
+        assert model_copy is not None
+        model_copy.cpu()
+        self.model_snapshots.append(model_copy)
+        self.model_snapshots = self.model_snapshots[-self.se_lastk:]
+
     def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
         """
         Optional place holder for AutoPytorch Extensions.
@@ -344,39 +373,19 @@ def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
         if X['is_cyclic_scheduler']:
             if hasattr(self.scheduler, 'T_cur') and self.scheduler.T_cur == 0 and epoch != 1:
                 if self.use_stochastic_weight_averaging:
-                    assert self.swa_model is not None, "SWA model can't be none when" \
-                        " stochastic weight averaging is enabled"
-                    self.swa_model.update_parameters(self.model)
-                    self.swa_updated = True
+                    self._swa_update()
                 if self.use_snapshot_ensemble:
-                    assert self.model_snapshots is not None, "model snapshots container can't be " \
-                        "none when snapshot ensembling is enabled"
-                    is_last_epoch = (epoch == self.budget_tracker.max_epochs)
-                    if is_last_epoch and self.use_stochastic_weight_averaging:
-                        model_copy = deepcopy(self.swa_model)
-                    else:
-                        model_copy = deepcopy(self.model)
-
-                    assert model_copy is not None
-                    model_copy.cpu()
-                    self.model_snapshots.append(model_copy)
-                    self.model_snapshots = self.model_snapshots[-self.se_lastk:]
+                    self._se_update(epoch=epoch)
         else:
-            if epoch > self._budget_threshold:
-                if self.use_stochastic_weight_averaging:
-                    assert self.swa_model is not None, "SWA model can't be none when" \
-                        " stochastic weight averaging is enabled"
-                    self.swa_model.update_parameters(self.model)
-                    self.swa_updated = True
-                if self.use_snapshot_ensemble:
-                    assert self.model_snapshots is not None, "model snapshots container can't be " \
-                        "none when snapshot ensembling is enabled"
-                    model_copy = deepcopy(self.swa_model) if self.use_stochastic_weight_averaging \
-                        else deepcopy(self.model)
-                    assert model_copy is not None
-                    model_copy.cpu()
-                    self.model_snapshots.append(model_copy)
-                    self.model_snapshots = self.model_snapshots[-self.se_lastk:]
+            if epoch > self._budget_threshold and self.use_stochastic_weight_averaging:
+                self._swa_update()
+
+            if (
+                self.use_snapshot_ensemble
+                and self.budget_tracker.max_epochs is not None
+                and epoch > (self.budget_tracker.max_epochs - self.se_lastk)
+            ):
+                self._se_update(epoch=epoch)
         return False

     def _scheduler_step(
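Annotation: with this fix, non-cyclic schedulers keep averaging SWA weights every epoch once the budget threshold is passed, but ensemble snapshots are now taken only in the final `se_lastk` epochs instead of at every epoch past the threshold. A minimal sketch of the `torch.optim.swa_utils` machinery that `_swa_update()` wraps; the loop and threshold below are illustrative, not autoPyTorch's trainer:

    import torch
    from torch.optim.swa_utils import AveragedModel

    model = torch.nn.Linear(10, 2)
    swa_model = AveragedModel(model)  # keeps a running average of model's weights

    budget_threshold, max_epochs = 60, 100
    for epoch in range(1, max_epochs + 1):
        ...  # one epoch of training on `model`
        if epoch > budget_threshold:
            # fold the current weights into the running average
            swa_model.update_parameters(model)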
From 989f3ac0dccf6b3367f2d9f197057e1893c7d6a3 Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Mon, 28 Feb 2022 15:31:12 +0100
Subject: [PATCH 2/4] Add forbidden conditions for cyclic LR

---
 .../pipeline/tabular_classification.py     | 23 +++++++++++++++++++
 autoPyTorch/pipeline/tabular_regression.py | 23 +++++++++++++++++++
 2 files changed, 46 insertions(+)

diff --git a/autoPyTorch/pipeline/tabular_classification.py b/autoPyTorch/pipeline/tabular_classification.py
index 720d0af64..90acc1656 100644
--- a/autoPyTorch/pipeline/tabular_classification.py
+++ b/autoPyTorch/pipeline/tabular_classification.py
@@ -289,6 +289,29 @@ def _get_hyperparameter_search_space(self,
                                 raise ValueError("Cannot find a legal default configuration")
                             cs.get_hyperparameter('network_embedding:__choice__').default_value = default

+        # Disable CyclicLR until todo is completed.
+        if 'lr_scheduler' in self.named_steps.keys() and 'trainer' in self.named_steps.keys():
+            trainers = cs.get_hyperparameter('trainer:__choice__').choices
+            for trainer in trainers:
+                available_schedulers = self.named_steps['lr_scheduler'].get_available_components(
+                    dataset_properties=dataset_properties,
+                    exclude=exclude if bool(exclude) else None,
+                    include=include if bool(include) else None)
+                # TODO: update cyclic lr to use n_restarts and adjust according to batch size
+                cyclic_lr_name = 'CyclicLR'
+                if cyclic_lr_name in available_schedulers:
+                    # disable snapshot ensembles and stochastic weight averaging
+                    cs.add_forbidden_clause(ForbiddenAndConjunction(
+                        ForbiddenEqualsClause(cs.get_hyperparameter(
+                            f'trainer:{trainer}:use_snapshot_ensemble'), True),
+                        ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
+                    ))
+                    cs.add_forbidden_clause(ForbiddenAndConjunction(
+                        ForbiddenEqualsClause(cs.get_hyperparameter(
+                            f'trainer:{trainer}:use_stochastic_weight_averaging'), True),
+                        ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
+                    ))
+
         self.configuration_space = cs
         self.dataset_properties = dataset_properties
         return cs
diff --git a/autoPyTorch/pipeline/tabular_regression.py b/autoPyTorch/pipeline/tabular_regression.py
index 06da9cabb..105506ba2 100644
--- a/autoPyTorch/pipeline/tabular_regression.py
+++ b/autoPyTorch/pipeline/tabular_regression.py
@@ -238,6 +238,29 @@ def _get_hyperparameter_search_space(self,
                                 raise ValueError("Cannot find a legal default configuration")
                             cs.get_hyperparameter('network_embedding:__choice__').default_value = default

+        # Disable CyclicLR until todo is completed.
+        if 'lr_scheduler' in self.named_steps.keys() and 'trainer' in self.named_steps.keys():
+            trainers = cs.get_hyperparameter('trainer:__choice__').choices
+            for trainer in trainers:
+                available_schedulers = self.named_steps['lr_scheduler'].get_available_components(
+                    dataset_properties=dataset_properties,
+                    exclude=exclude if bool(exclude) else None,
+                    include=include if bool(include) else None)
+                # TODO: update cyclic lr to use n_restarts and adjust according to batch size
+                cyclic_lr_name = 'CyclicLR'
+                if cyclic_lr_name in available_schedulers:
+                    # disable snapshot ensembles and stochastic weight averaging
+                    cs.add_forbidden_clause(ForbiddenAndConjunction(
+                        ForbiddenEqualsClause(cs.get_hyperparameter(
+                            f'trainer:{trainer}:use_snapshot_ensemble'), True),
+                        ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
+                    ))
+                    cs.add_forbidden_clause(ForbiddenAndConjunction(
+                        ForbiddenEqualsClause(cs.get_hyperparameter(
+                            f'trainer:{trainer}:use_stochastic_weight_averaging'), True),
+                        ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
+                    ))
+
         self.configuration_space = cs
         self.dataset_properties = dataset_properties
         return cs
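Annotation: the `ForbiddenAndConjunction` clauses above are what keep the optimizer from ever sampling SWA or snapshot ensembling together with CyclicLR. A self-contained sketch of the pattern; the hyperparameter names are illustrative stand-ins for the pipeline's generated ones:

    from ConfigSpace.configuration_space import ConfigurationSpace
    from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause
    from ConfigSpace.hyperparameters import CategoricalHyperparameter

    cs = ConfigurationSpace()
    scheduler = CategoricalHyperparameter(
        'lr_scheduler:__choice__', ['CosineAnnealingLR', 'CyclicLR'],
        default_value='CosineAnnealingLR')
    use_swa = CategoricalHyperparameter(
        'trainer:use_stochastic_weight_averaging', [True, False], default_value=False)
    cs.add_hyperparameters([scheduler, use_swa])

    # sample_configuration() will never return use_swa=True together with CyclicLR
    cs.add_forbidden_clause(ForbiddenAndConjunction(
        ForbiddenEqualsClause(use_swa, True),
        ForbiddenEqualsClause(scheduler, 'CyclicLR'),
    ))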
From cda66c89d94ebc285248331d3aa66e6c0f0d01e9 Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Wed, 2 Mar 2022 17:10:49 +0100
Subject: [PATCH 3/4] Refactor base_pipeline forbidden conditions

---
 autoPyTorch/pipeline/base_pipeline.py        | 63 +++++++++++++++++++
 autoPyTorch/pipeline/image_classification.py |  1 +
 .../pipeline/tabular_classification.py       | 54 +---------------
 autoPyTorch/pipeline/tabular_regression.py   | 52 +--------------
 4 files changed, 68 insertions(+), 102 deletions(-)

diff --git a/autoPyTorch/pipeline/base_pipeline.py b/autoPyTorch/pipeline/base_pipeline.py
index fc15d9fed..fe9727502 100644
--- a/autoPyTorch/pipeline/base_pipeline.py
+++ b/autoPyTorch/pipeline/base_pipeline.py
@@ -1,3 +1,4 @@
+from copy import copy
 import warnings
 from abc import ABCMeta
 from collections import Counter
@@ -5,6 +6,7 @@

 from ConfigSpace import Configuration
 from ConfigSpace.configuration_space import ConfigurationSpace
+from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause

 import numpy as np
@@ -295,6 +297,67 @@ def _get_hyperparameter_search_space(self,
         """
         raise NotImplementedError()

+    def _add_forbidden_conditions(self, cs: ConfigurationSpace) -> ConfigurationSpace:
+        """
+        Add forbidden conditions to ensure valid configurations.
+        Currently, Learned Entity Embedding is only valid when encoder is one hot encoder
+        and CyclicLR is disabled when using stochastic weight averaging and snapshot
+        ensembling.
+
+        Args:
+            cs (ConfigurationSpace):
+                Configuration space to which forbidden conditions are added.
+
+        """
+
+        # Learned Entity Embedding is only valid when encoder is one hot encoder
+        if 'network_embedding' in self.named_steps.keys() and 'encoder' in self.named_steps.keys():
+            embeddings = cs.get_hyperparameter('network_embedding:__choice__').choices
+            if 'LearnedEntityEmbedding' in embeddings:
+                encoders = cs.get_hyperparameter('encoder:__choice__').choices
+                possible_default_embeddings = copy(list(embeddings))
+                del possible_default_embeddings[possible_default_embeddings.index('LearnedEntityEmbedding')]
+
+                for encoder in encoders:
+                    if encoder == 'OneHotEncoder':
+                        continue
+                    while True:
+                        try:
+                            cs.add_forbidden_clause(ForbiddenAndConjunction(
+                                ForbiddenEqualsClause(cs.get_hyperparameter(
+                                    'network_embedding:__choice__'), 'LearnedEntityEmbedding'),
+                                ForbiddenEqualsClause(cs.get_hyperparameter('encoder:__choice__'), encoder)
+                            ))
+                            break
+                        except ValueError:
+                            # change the default and try again
+                            try:
+                                default = possible_default_embeddings.pop()
+                            except IndexError:
+                                raise ValueError("Cannot find a legal default configuration")
+                            cs.get_hyperparameter('network_embedding:__choice__').default_value = default
+
+        # Disable CyclicLR until todo is completed.
+        if 'lr_scheduler' in self.named_steps.keys() and 'trainer' in self.named_steps.keys():
+            trainers = cs.get_hyperparameter('trainer:__choice__').choices
+            for trainer in trainers:
+                available_schedulers = cs.get_hyperparameter('lr_scheduler:__choice__').choices
+                # TODO: update cyclic lr to use n_restarts and adjust according to batch size
+                cyclic_lr_name = 'CyclicLR'
+                if cyclic_lr_name in available_schedulers:
+                    # disable snapshot ensembles and stochastic weight averaging
+                    cs.add_forbidden_clause(ForbiddenAndConjunction(
+                        ForbiddenEqualsClause(cs.get_hyperparameter(
+                            f'trainer:{trainer}:use_snapshot_ensemble'), True),
+                        ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
+                    ))
+                    cs.add_forbidden_clause(ForbiddenAndConjunction(
+                        ForbiddenEqualsClause(cs.get_hyperparameter(
+                            f'trainer:{trainer}:use_stochastic_weight_averaging'), True),
+                        ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
+                    ))
+        return cs
+
     def __repr__(self) -> str:
         """Retrieves a str representation of the current pipeline

diff --git a/autoPyTorch/pipeline/image_classification.py b/autoPyTorch/pipeline/image_classification.py
index 276e05816..13f8a4cf8 100644
--- a/autoPyTorch/pipeline/image_classification.py
+++ b/autoPyTorch/pipeline/image_classification.py
@@ -156,6 +156,7 @@ def _get_hyperparameter_search_space(self,

         # Here we add custom code, like this with this
         # is not a valid configuration
+        cs = self._add_forbidden_conditions(cs)

         self.configuration_space = cs
         self.dataset_properties = dataset_properties
diff --git a/autoPyTorch/pipeline/tabular_classification.py b/autoPyTorch/pipeline/tabular_classification.py
index 90acc1656..2e64a6944 100644
--- a/autoPyTorch/pipeline/tabular_classification.py
+++ b/autoPyTorch/pipeline/tabular_classification.py
@@ -3,7 +3,6 @@
 from typing import Any, Dict, List, Optional, Tuple, Union

 from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
-from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause

 import numpy as np
@@ -261,56 +260,9 @@ def _get_hyperparameter_search_space(self,
             cs=cs, dataset_properties=dataset_properties,
             exclude=exclude, include=include, pipeline=self.steps)

-        # Here we add custom code, that is used to ensure valid configurations, For example
-        # Learned Entity Embedding is only valid when encoder is one hot encoder
-        if 'network_embedding' in self.named_steps.keys() and 'encoder' in self.named_steps.keys():
-            embeddings = cs.get_hyperparameter('network_embedding:__choice__').choices
-            if 'LearnedEntityEmbedding' in embeddings:
-                encoders = cs.get_hyperparameter('encoder:__choice__').choices
-                possible_default_embeddings = copy.copy(list(embeddings))
-                del possible_default_embeddings[possible_default_embeddings.index('LearnedEntityEmbedding')]
-
-                for encoder in encoders:
-                    if encoder == 'OneHotEncoder':
-                        continue
-                    while True:
-                        try:
-                            cs.add_forbidden_clause(ForbiddenAndConjunction(
-                                ForbiddenEqualsClause(cs.get_hyperparameter(
-                                    'network_embedding:__choice__'), 'LearnedEntityEmbedding'),
-                                ForbiddenEqualsClause(cs.get_hyperparameter('encoder:__choice__'), encoder)
-                            ))
-                            break
-                        except ValueError:
-                            # change the default and try again
-                            try:
-                                default = possible_default_embeddings.pop()
-                            except IndexError:
-                                raise ValueError("Cannot find a legal default configuration")
-                            cs.get_hyperparameter('network_embedding:__choice__').default_value = default
-
-        # Disable CyclicLR until todo is completed.
-        if 'lr_scheduler' in self.named_steps.keys() and 'trainer' in self.named_steps.keys():
-            trainers = cs.get_hyperparameter('trainer:__choice__').choices
-            for trainer in trainers:
-                available_schedulers = self.named_steps['lr_scheduler'].get_available_components(
-                    dataset_properties=dataset_properties,
-                    exclude=exclude if bool(exclude) else None,
-                    include=include if bool(include) else None)
-                # TODO: update cyclic lr to use n_restarts and adjust according to batch size
-                cyclic_lr_name = 'CyclicLR'
-                if cyclic_lr_name in available_schedulers:
-                    # disable snapshot ensembles and stochastic weight averaging
-                    cs.add_forbidden_clause(ForbiddenAndConjunction(
-                        ForbiddenEqualsClause(cs.get_hyperparameter(
-                            f'trainer:{trainer}:use_snapshot_ensemble'), True),
-                        ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
-                    ))
-                    cs.add_forbidden_clause(ForbiddenAndConjunction(
-                        ForbiddenEqualsClause(cs.get_hyperparameter(
-                            f'trainer:{trainer}:use_stochastic_weight_averaging'), True),
-                        ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
-                    ))
+        # Here we add custom code, like this with this
+        # is not a valid configuration
+        cs = self._add_forbidden_conditions(cs)

         self.configuration_space = cs
         self.dataset_properties = dataset_properties
diff --git a/autoPyTorch/pipeline/tabular_regression.py b/autoPyTorch/pipeline/tabular_regression.py
index 105506ba2..4737bf57d 100644
--- a/autoPyTorch/pipeline/tabular_regression.py
+++ b/autoPyTorch/pipeline/tabular_regression.py
@@ -3,7 +3,6 @@
 from typing import Any, Dict, List, Optional, Tuple, Union

 from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
-from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause

 import numpy as np
@@ -210,56 +209,7 @@ def _get_hyperparameter_search_space(self,

         # Here we add custom code, like this with this
         # is not a valid configuration
-        # Learned Entity Embedding is only valid when encoder is one hot encoder
-        if 'network_embedding' in self.named_steps.keys() and 'encoder' in self.named_steps.keys():
-            embeddings = cs.get_hyperparameter('network_embedding:__choice__').choices
-            if 'LearnedEntityEmbedding' in embeddings:
-                encoders = cs.get_hyperparameter('encoder:__choice__').choices
-                default = cs.get_hyperparameter('network_embedding:__choice__').default_value
-                possible_default_embeddings = copy.copy(list(embeddings))
-                del possible_default_embeddings[possible_default_embeddings.index(default)]
-
-                for encoder in encoders:
-                    if encoder == 'OneHotEncoder':
-                        continue
-                    while True:
-                        try:
-                            cs.add_forbidden_clause(ForbiddenAndConjunction(
-                                ForbiddenEqualsClause(cs.get_hyperparameter(
-                                    'network_embedding:__choice__'), 'LearnedEntityEmbedding'),
-                                ForbiddenEqualsClause(cs.get_hyperparameter('encoder:__choice__'), encoder)
-                            ))
-                            break
-                        except ValueError:
-                            # change the default and try again
-                            try:
-                                default = possible_default_embeddings.pop()
-                            except IndexError:
-                                raise ValueError("Cannot find a legal default configuration")
-                            cs.get_hyperparameter('network_embedding:__choice__').default_value = default
-
-        # Disable CyclicLR until todo is completed.
-        if 'lr_scheduler' in self.named_steps.keys() and 'trainer' in self.named_steps.keys():
-            trainers = cs.get_hyperparameter('trainer:__choice__').choices
-            for trainer in trainers:
-                available_schedulers = self.named_steps['lr_scheduler'].get_available_components(
-                    dataset_properties=dataset_properties,
-                    exclude=exclude if bool(exclude) else None,
-                    include=include if bool(include) else None)
-                # TODO: update cyclic lr to use n_restarts and adjust according to batch size
-                cyclic_lr_name = 'CyclicLR'
-                if cyclic_lr_name in available_schedulers:
-                    # disable snapshot ensembles and stochastic weight averaging
-                    cs.add_forbidden_clause(ForbiddenAndConjunction(
-                        ForbiddenEqualsClause(cs.get_hyperparameter(
-                            f'trainer:{trainer}:use_snapshot_ensemble'), True),
-                        ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
-                    ))
-                    cs.add_forbidden_clause(ForbiddenAndConjunction(
-                        ForbiddenEqualsClause(cs.get_hyperparameter(
-                            f'trainer:{trainer}:use_stochastic_weight_averaging'), True),
-                        ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
-                    ))
+        cs = self._add_forbidden_conditions(cs)

         self.configuration_space = cs
         self.dataset_properties = dataset_properties
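Annotation: the while/try loop inside `_add_forbidden_conditions` exists because ConfigSpace rejects a forbidden clause that would outlaw the current default configuration. A minimal sketch of that behaviour, assuming ConfigSpace raises `ValueError` in that case (which is what the retry loop above relies on):

    from ConfigSpace.configuration_space import ConfigurationSpace
    from ConfigSpace.forbidden import ForbiddenEqualsClause
    from ConfigSpace.hyperparameters import CategoricalHyperparameter

    cs = ConfigurationSpace()
    embedding = CategoricalHyperparameter(
        'network_embedding:__choice__', ['LearnedEntityEmbedding', 'NoEmbedding'],
        default_value='LearnedEntityEmbedding')
    cs.add_hyperparameter(embedding)

    try:
        # forbids the default value -> the default configuration becomes illegal
        cs.add_forbidden_clause(
            ForbiddenEqualsClause(embedding, 'LearnedEntityEmbedding'))
    except ValueError:
        # change the default and try again, as the pipeline code does
        embedding.default_value = 'NoEmbedding'
        cs.add_forbidden_clause(
            ForbiddenEqualsClause(embedding, 'LearnedEntityEmbedding'))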
From 0cc6b4696abcd2e912ebd4fc9a6c283ef5607b2e Mon Sep 17 00:00:00 2001
From: Ravin Kohli <13005107+ravinkohli@users.noreply.github.com>
Date: Wed, 2 Mar 2022 17:11:30 +0100
Subject: [PATCH 4/4] Apply suggestions from code review

Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
---
 .../components/setup/network_embedding/__init__.py       | 6 ++++--
 .../pipeline/components/training/trainer/base_trainer.py | 8 ++++----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/autoPyTorch/pipeline/components/setup/network_embedding/__init__.py b/autoPyTorch/pipeline/components/setup/network_embedding/__init__.py
index 381e0735d..6939842f4 100644
--- a/autoPyTorch/pipeline/components/setup/network_embedding/__init__.py
+++ b/autoPyTorch/pipeline/components/setup/network_embedding/__init__.py
@@ -155,8 +155,10 @@ def get_hyperparameter_search_space(
                     default = default_
                     break

-        categorical_columns = dataset_properties['categorical_columns'] \
-            if isinstance(dataset_properties['categorical_columns'], List) else []
+        if isinstance(dataset_properties['categorical_columns'], list):
+            categorical_columns = dataset_properties['categorical_columns']
+        else:
+            categorical_columns = []

         updates = self._get_search_space_updates()
         if '__choice__' in updates.keys():
diff --git a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
index cf634953a..62768d374 100644
--- a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
+++ b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
@@ -338,8 +338,8 @@ def _swa_update(self) -> None:
         """
         perform swa model update
         """
-        assert self.swa_model is not None, "SWA model can't be none when" \
-            " stochastic weight averaging is enabled"
+        if self.swa_model is None:
+            raise ValueError("SWA model cannot be none when stochastic weight averaging is enabled")
         self.swa_model.update_parameters(self.model)
         self.swa_updated = True
@@ -350,8 +350,8 @@ def _se_update(self, epoch: int) -> None:
             epoch (int):
                 current epoch
         """
-        assert self.model_snapshots is not None, "model snapshots container can't be " \
-            "none when snapshot ensembling is enabled"
+        if self.model_snapshots is None:
+            raise ValueError("model snapshots cannot be None when snapshot ensembling is enabled")
         is_last_epoch = (epoch == self.budget_tracker.max_epochs)
         if is_last_epoch and self.use_stochastic_weight_averaging:
             model_copy = deepcopy(self.swa_model)