Skip to content

Commit

Permalink
Reg cocktails common paper modifications 2 (#417)
Browse files Browse the repository at this point in the history
* remove remaining differences

* Reg cocktails common paper modifications 5 (#418)

* add hasattr

* fix run summary
  • Loading branch information
ravinkohli committed Oct 25, 2022
1 parent 76dae54 commit d26c611
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
22 changes: 22 additions & 0 deletions autoPyTorch/evaluation/train_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
from multiprocessing.queues import Queue
import os
from typing import Any, Dict, List, Optional, Tuple, Union

from ConfigSpace.configuration_space import Configuration
Expand All @@ -21,6 +23,7 @@
)
from autoPyTorch.evaluation.utils import DisableFileOutputParameters
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
from autoPyTorch.utils.common import dict_repr, subsampler
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates

Expand Down Expand Up @@ -193,6 +196,25 @@ def fit_predict_and_loss(self) -> None:
additional_run_info = pipeline.get_additional_run_info() if hasattr(
pipeline, 'get_additional_run_info') else {}

# add learning curve of configurations to additional_run_info
if isinstance(pipeline, TabularClassificationPipeline):
if hasattr(pipeline.named_steps['trainer'], 'run_summary'):
run_summary = pipeline.named_steps['trainer'].run_summary
split_types = ['train', 'val', 'test']
run_summary_dict = dict(
run_summary={},
budget=self.budget,
seed=self.seed,
config_id=self.configuration.config_id,
num_run=self.num_run
)
for split_type in split_types:
run_summary_dict['run_summary'][f'{split_type}_loss'] = run_summary.performance_tracker.get(f'{split_type}_loss', None)
run_summary_dict['run_summary'][f'{split_type}_metrics'] = run_summary.performance_tracker.get(f'{split_type}_metrics', None)
self.logger.debug(f"run_summary_dict {json.dumps(run_summary_dict)}")
with open(os.path.join(self.backend.temporary_directory, 'run_summary.txt'), 'a') as file:
file.write(f"{json.dumps(run_summary_dict)}\n")

status = StatusType.SUCCESS

self.logger.debug("In train evaluator.fit_predict_and_loss, num_run: {} loss:{},"
Expand Down
4 changes: 2 additions & 2 deletions autoPyTorch/pipeline/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,13 +351,13 @@ def _add_forbidden_conditions(self, cs: ConfigurationSpace) -> ConfigurationSpac
if cyclic_lr_name in available_schedulers:
# disable snapshot ensembles and stochastic weight averaging
snapshot_ensemble_hyperparameter = cs.get_hyperparameter(f'trainer:{trainer}:use_snapshot_ensemble')
if True in snapshot_ensemble_hyperparameter.choices:
if hasattr(snapshot_ensemble_hyperparameter, 'choices') and True in snapshot_ensemble_hyperparameter.choices:
cs.add_forbidden_clause(ForbiddenAndConjunction(
ForbiddenEqualsClause(snapshot_ensemble_hyperparameter, True),
ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
))
swa_hyperparameter = cs.get_hyperparameter(f'trainer:{trainer}:use_stochastic_weight_averaging')
if True in swa_hyperparameter.choices:
if hasattr(swa_hyperparameter, 'choices') and True in swa_hyperparameter.choices:
cs.add_forbidden_clause(ForbiddenAndConjunction(
ForbiddenEqualsClause(swa_hyperparameter, True),
ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_hyperparameter_search_space(self,
"choices in {} got {}".format(self.__class__.__name__,
available_preprocessors,
choice_hyperparameter.value_range))
if len(choice_hyperparameter) == 0:
if len(categorical_columns) == 0:
assert len(choice_hyperparameter.value_range) == 1
assert 'NoEncoder' in choice_hyperparameter.value_range, \
"Provided {} in choices, however, the dataset " \
Expand Down

0 comments on commit d26c611

Please sign in to comment.