Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: nabenabe0928 <[email protected]>
  • Loading branch information
ravinkohli and nabenabe0928 authored Mar 2, 2022
1 parent cda66c8 commit 0cc6b46
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,10 @@ def get_hyperparameter_search_space(
default = default_
break

categorical_columns = dataset_properties['categorical_columns'] \
if isinstance(dataset_properties['categorical_columns'], List) else []
if isinstance(dataset_properties['categorical_columns'], list):
categorical_columns = dataset_properties['categorical_columns']
else:
categorical_columns = []

updates = self._get_search_space_updates()
if '__choice__' in updates.keys():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,8 @@ def _swa_update(self) -> None:
"""
perform swa model update
"""
assert self.swa_model is not None, "SWA model can't be none when" \
" stochastic weight averaging is enabled"
if self.swa_model is None:
raise ValueError("SWA model cannot be none when stochastic weight averaging is enabled")
self.swa_model.update_parameters(self.model)
self.swa_updated = True

Expand All @@ -350,8 +350,8 @@ def _se_update(self, epoch: int) -> None:
epoch (int):
current epoch
"""
assert self.model_snapshots is not None, "model snapshots container can't be " \
"none when snapshot ensembling is enabled"
if self.model_snapshots is None:
raise ValueError("model snapshots cannot be None when snapshot ensembling is enabled")
is_last_epoch = (epoch == self.budget_tracker.max_epochs)
if is_last_epoch and self.use_stochastic_weight_averaging:
model_copy = deepcopy(self.swa_model)
Expand Down

0 comments on commit 0cc6b46

Please sign in to comment.