-
Notifications
You must be signed in to change notification settings - Fork 284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FIX] SWA and SE with non cyclic schedulers #395
Changes from all commits
8bf6280
989f3ac
cda66c8
0cc6b46
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -148,71 +148,64 @@ def get_hyperparameter_search_space( | |
if default is None: | ||
defaults = [ | ||
'NoEmbedding', | ||
# 'LearnedEntityEmbedding', | ||
'LearnedEntityEmbedding', | ||
] | ||
for default_ in defaults: | ||
if default_ in available_embedding: | ||
default = default_ | ||
break | ||
|
||
# Restrict embedding to NoEmbedding until preprocessing is fixed | ||
embedding = CSH.CategoricalHyperparameter('__choice__', | ||
['NoEmbedding'], | ||
default_value=default) | ||
if isinstance(dataset_properties['categorical_columns'], list): | ||
categorical_columns = dataset_properties['categorical_columns'] | ||
else: | ||
categorical_columns = [] | ||
|
||
updates = self._get_search_space_updates() | ||
if '__choice__' in updates.keys(): | ||
choice_hyperparameter = updates['__choice__'] | ||
if not set(choice_hyperparameter.value_range).issubset(available_embedding): | ||
raise ValueError("Expected given update for {} to have " | ||
"choices in {} got {}".format(self.__class__.__name__, | ||
available_embedding, | ||
choice_hyperparameter.value_range)) | ||
if len(categorical_columns) == 0: | ||
assert len(choice_hyperparameter.value_range) == 1 | ||
if 'NoEmbedding' not in choice_hyperparameter.value_range: | ||
raise ValueError("Provided {} in choices, however, the dataset " | ||
"is incompatible with it".format(choice_hyperparameter.value_range)) | ||
embedding = CSH.CategoricalHyperparameter('__choice__', | ||
choice_hyperparameter.value_range, | ||
default_value=choice_hyperparameter.default_value) | ||
else: | ||
|
||
if len(categorical_columns) == 0: | ||
default = 'NoEmbedding' | ||
if include is not None and default not in include: | ||
raise ValueError("Provided {} in include, however, the dataset " | ||
"is incompatible with it".format(include)) | ||
embedding = CSH.CategoricalHyperparameter('__choice__', | ||
['NoEmbedding'], | ||
default_value=default) | ||
else: | ||
embedding = CSH.CategoricalHyperparameter('__choice__', | ||
list(available_embedding.keys()), | ||
default_value=default) | ||
Comment on lines +163 to +192
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Can we create a method for this?
def _get_embedding_config(self, updates, avail_embeddings, default: str = "NoEmbedding"):
if '__choice__' in updates.keys():
choices, default = updates['__choice__'].value_range, updates['__choice__'].default_value
if not set(choices).issubset(avail_embeddings):
raise ValueError(
f"The choices for {self.__class__.__name__} must be in {avail_embeddings}, but got {choices}"
)
if len(categorical_columns) == 0:
assert len(choices) == 1
if 'NoEmbedding' not in choices:
raise ValueError(f"The choices must include `NoEmbedding`, but got {choices}")
embedding = CSH.CategoricalHyperparameter('__choice__', choices, default_value=default)
elif len(categorical_columns) == 0:
default = 'NoEmbedding'
if include is not None and default not in include:
raise ValueError(f"default `{default}` must be in `include`: {include}")
embedding = CSH.CategoricalHyperparameter('__choice__', ['NoEmbedding'], default_value=default)
else:
choices = list(available_embedding.keys())
embedding = CSH.CategoricalHyperparameter('__choice__', choices, default_value=default)
return embedding
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Let me do that after we are done with the experiments. There is an issue raised for this: #377. |
||
|
||
cs.add_hyperparameter(embedding) | ||
for name in embedding.choices: | ||
updates = self._get_search_space_updates(prefix=name) | ||
config_space = available_embedding[name].get_hyperparameter_search_space(dataset_properties, # type: ignore | ||
**updates) | ||
parent_hyperparameter = {'parent': embedding, 'value': name} | ||
cs.add_configuration_space( | ||
name, | ||
config_space, | ||
parent_hyperparameter=parent_hyperparameter | ||
) | ||
|
||
self.configuration_space_ = cs | ||
self.dataset_properties_ = dataset_properties | ||
return cs | ||
# categorical_columns = dataset_properties['categorical_columns'] \ | ||
# if isinstance(dataset_properties['categorical_columns'], List) else [] | ||
|
||
# updates = self._get_search_space_updates() | ||
# if '__choice__' in updates.keys(): | ||
# choice_hyperparameter = updates['__choice__'] | ||
# if not set(choice_hyperparameter.value_range).issubset(available_embedding): | ||
# raise ValueError("Expected given update for {} to have " | ||
# "choices in {} got {}".format(self.__class__.__name__, | ||
# available_embedding, | ||
# choice_hyperparameter.value_range)) | ||
# if len(categorical_columns) == 0: | ||
# assert len(choice_hyperparameter.value_range) == 1 | ||
# if 'NoEmbedding' not in choice_hyperparameter.value_range: | ||
# raise ValueError("Provided {} in choices, however, the dataset " | ||
# "is incompatible with it".format(choice_hyperparameter.value_range)) | ||
# embedding = CSH.CategoricalHyperparameter('__choice__', | ||
# choice_hyperparameter.value_range, | ||
# default_value=choice_hyperparameter.default_value) | ||
# else: | ||
|
||
# if len(categorical_columns) == 0: | ||
# default = 'NoEmbedding' | ||
# if include is not None and default not in include: | ||
# raise ValueError("Provided {} in include, however, the dataset " | ||
# "is incompatible with it".format(include)) | ||
# embedding = CSH.CategoricalHyperparameter('__choice__', | ||
# ['NoEmbedding'], | ||
# default_value=default) | ||
# else: | ||
# embedding = CSH.CategoricalHyperparameter('__choice__', | ||
# list(available_embedding.keys()), | ||
# default_value=default) | ||
|
||
# cs.add_hyperparameter(embedding) | ||
# for name in embedding.choices: | ||
# updates = self._get_search_space_updates(prefix=name) | ||
# config_space = available_embedding[name].get_hyperparameter_search_space( | ||
# dataset_properties, # type: ignore | ||
# **updates) | ||
# parent_hyperparameter = {'parent': embedding, 'value': name} | ||
# cs.add_configuration_space( | ||
# name, | ||
# config_space, | ||
# parent_hyperparameter=parent_hyperparameter | ||
# ) | ||
|
||
# self.configuration_space_ = cs | ||
# self.dataset_properties_ = dataset_properties | ||
# return cs | ||
|
||
def transform(self, X: np.ndarray) -> np.ndarray: | ||
assert self.choice is not None, "Cannot call transform before the object is initialized" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -334,6 +334,35 @@ def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None: | |
""" | ||
pass | ||
|
||
def _swa_update(self) -> None:
    """
    Perform the stochastic weight averaging (SWA) model update.

    Passes ``self.model`` to ``self.swa_model.update_parameters`` and then
    sets ``self.swa_updated`` to True so that other code can tell the SWA
    model has received at least one update.

    Raises:
        ValueError:
            if ``self.swa_model`` is None
    """
    if self.swa_model is None:
        raise ValueError("SWA model cannot be none when stochastic weight averaging is enabled")
    self.swa_model.update_parameters(self.model)
    self.swa_updated = True
|
||
def _se_update(self, epoch: int) -> None:
    """
    Add the latest model or swa_model to the model snapshot ensemble.

    The SWA model is snapshotted only when ``epoch`` equals
    ``self.budget_tracker.max_epochs`` and stochastic weight averaging is
    enabled; in every other case the plain model is snapshotted. The copy
    is moved to the CPU before being stored, and only the last
    ``self.se_lastk`` snapshots are kept.

    Args:
        epoch (int):
            current epoch

    Raises:
        ValueError:
            if ``self.model_snapshots`` is None, or if the model selected
            for the snapshot is None
    """
    if self.model_snapshots is None:
        raise ValueError("model snapshots cannot be None when snapshot ensembling is enabled")

    is_last_epoch = (epoch == self.budget_tracker.max_epochs)
    if is_last_epoch and self.use_stochastic_weight_averaging:
        source_model = self.swa_model
    else:
        source_model = self.model

    # Explicit check instead of an ``assert``: asserts are stripped under
    # ``python -O``, which would turn a missing model into an opaque
    # AttributeError on ``.cpu()`` below.
    if source_model is None:
        raise ValueError("model to snapshot cannot be None when snapshot ensembling is enabled")

    model_copy = deepcopy(source_model)
    model_copy.cpu()
    self.model_snapshots.append(model_copy)
    # Keep only the most recent ``se_lastk`` snapshots.
    self.model_snapshots = self.model_snapshots[-self.se_lastk:]
|
||
def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool: | ||
""" | ||
Optional place holder for AutoPytorch Extensions. | ||
|
@@ -344,39 +373,19 @@ def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool: | |
if X['is_cyclic_scheduler']: | ||
if hasattr(self.scheduler, 'T_cur') and self.scheduler.T_cur == 0 and epoch != 1: | ||
if self.use_stochastic_weight_averaging: | ||
assert self.swa_model is not None, "SWA model can't be none when" \ | ||
" stochastic weight averaging is enabled" | ||
self.swa_model.update_parameters(self.model) | ||
self.swa_updated = True | ||
self._swa_update() | ||
if self.use_snapshot_ensemble: | ||
assert self.model_snapshots is not None, "model snapshots container can't be " \ | ||
"none when snapshot ensembling is enabled" | ||
is_last_epoch = (epoch == self.budget_tracker.max_epochs) | ||
if is_last_epoch and self.use_stochastic_weight_averaging: | ||
model_copy = deepcopy(self.swa_model) | ||
else: | ||
model_copy = deepcopy(self.model) | ||
|
||
assert model_copy is not None | ||
model_copy.cpu() | ||
self.model_snapshots.append(model_copy) | ||
self.model_snapshots = self.model_snapshots[-self.se_lastk:] | ||
self._se_update(epoch=epoch) | ||
else: | ||
if epoch > self._budget_threshold: | ||
if self.use_stochastic_weight_averaging: | ||
assert self.swa_model is not None, "SWA model can't be none when" \ | ||
" stochastic weight averaging is enabled" | ||
self.swa_model.update_parameters(self.model) | ||
self.swa_updated = True | ||
if self.use_snapshot_ensemble: | ||
assert self.model_snapshots is not None, "model snapshots container can't be " \ | ||
"none when snapshot ensembling is enabled" | ||
model_copy = deepcopy(self.swa_model) if self.use_stochastic_weight_averaging \ | ||
else deepcopy(self.model) | ||
Comment on lines -374 to -375
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Is it ok to add the condition
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
What do you mean?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
He caught a bug. The initial implementation was wrong when SE and SWA were both activated; it was fixed, but not for the case of a non-cyclic scheduler. Now that you have fixed it, the behaviour differs from the previous implementation: the previous one had no last-epoch check to add the SWA model only at the end, but added 3 copies of it throughout.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Basically, now it is ok and correct. |
||
assert model_copy is not None | ||
model_copy.cpu() | ||
self.model_snapshots.append(model_copy) | ||
self.model_snapshots = self.model_snapshots[-self.se_lastk:] | ||
if epoch > self._budget_threshold and self.use_stochastic_weight_averaging: | ||
self._swa_update() | ||
|
||
if ( | ||
self.use_snapshot_ensemble | ||
and self.budget_tracker.max_epochs is not None | ||
and epoch > (self.budget_tracker.max_epochs - self.se_lastk) | ||
): | ||
self._se_update(epoch=epoch) | ||
return False | ||
|
||
def _scheduler_step( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why would we get an error here, and why do we need to change the default?
It is not the case that we might have a default value that is not in the list of choices, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We will get a ValueError in case `LearnedEntityEmbedding` was the default value for the `__choice__` hyperparameter.