Change weighted loss to categorical and fix for test adversarial trainer (#214)
ravinkohli committed Feb 28, 2022
1 parent 1b5e1c3 commit 7f83ce3
Showing 6 changed files with 20 additions and 49 deletions.
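The substance of the change: weighted_loss was previously exposed to the hyperparameter optimizer as an integer Constant (value_range=[1]), so every sampled configuration silently trained with a weighted loss; it is now a boolean CategoricalHyperparameter over [True, False], which the optimizer can actually switch off. A minimal ConfigSpace sketch of the difference (illustrative only; the name weighted_loss_old is hypothetical, not from this commit):

from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import CategoricalHyperparameter, Constant

cs = ConfigurationSpace(seed=1)

# Before: a Constant pins the value, so the optimizer can never disable it.
cs.add_hyperparameter(Constant("weighted_loss_old", 1))

# After: a categorical over {True, False}, default True, is searchable.
cs.add_hyperparameter(CategoricalHyperparameter(
    "weighted_loss", choices=[True, False], default_value=True))

print(cs.sample_configuration())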
@@ -24,7 +24,7 @@ class AdversarialTrainer(BaseTrainerComponent):
     def __init__(
         self,
         epsilon: float,
-        weighted_loss: int = 0,
+        weighted_loss: bool = False,
         random_state: Optional[np.random.RandomState] = None,
         use_stochastic_weight_averaging: bool = False,
         use_snapshot_ensemble: bool = False,
@@ -159,8 +159,8 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=[1],
-            default_value=1),
+            value_range=[True, False],
+            default_value=True),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",
             value_range=(5, 10),
@@ -226,16 +226,9 @@ def get_hyperparameter_search_space(
             parent_hyperparameter=parent_hyperparameter
         )

-        """
-        # TODO, decouple the weighted loss from the trainer
-        if dataset_properties is not None:
-            if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
-                add_hyperparameter(cs, weighted_loss, CategoricalHyperparameter)
-        """
-        # TODO, decouple the weighted loss from the trainer. Uncomment the code above and
-        # remove the code below. Also update the method signature, so the weighted loss
-        # is not a constant.
         if dataset_properties is not None:
             if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
-                add_hyperparameter(cs, weighted_loss, Constant)
+                add_hyperparameter(cs, weighted_loss, CategoricalHyperparameter)

         return cs
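For context, add_hyperparameter(cs, weighted_loss, CategoricalHyperparameter) maps the HyperparameterSearchSpace declared in the method signature onto a ConfigSpace hyperparameter. A hedged sketch of roughly what that expansion amounts to; the stand-in classes below are simplified, and the real helper in autoPyTorch's utilities may differ:

from typing import Any, List, Type
from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import CategoricalHyperparameter, Hyperparameter

class HyperparameterSearchSpace:
    # Simplified stand-in with the three fields used throughout this diff.
    def __init__(self, hyperparameter: str, value_range: List[Any], default_value: Any):
        self.hyperparameter = hyperparameter
        self.value_range = value_range
        self.default_value = default_value

def add_hyperparameter(cs: ConfigurationSpace,
                       search_space: HyperparameterSearchSpace,
                       hp_type: Type[Hyperparameter]) -> None:
    # For a categorical hyperparameter, value_range supplies the choices.
    cs.add_hyperparameter(hp_type(search_space.hyperparameter,
                                  search_space.value_range,
                                  default_value=search_space.default_value))

cs = ConfigurationSpace()
weighted_loss = HyperparameterSearchSpace(hyperparameter="weighted_loss",
                                          value_range=[True, False],
                                          default_value=True)
add_hyperparameter(cs, weighted_loss, CategoricalHyperparameter)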
@@ -13,7 +13,8 @@


 class StandardTrainer(BaseTrainerComponent):
-    def __init__(self, weighted_loss: int = 0,
+    def __init__(self,
+                 weighted_loss: bool = False,
                  use_stochastic_weight_averaging: bool = False,
                  use_snapshot_ensemble: bool = False,
                  se_lastk: int = 3,
17 changes: 5 additions & 12 deletions autoPyTorch/pipeline/components/training/trainer/base_trainer.py
@@ -206,7 +206,7 @@ class BaseTrainerComponent(autoPyTorchTrainingComponent):
     """
     Base class for training
     Args:
-        weighted_loss (int, default=0): In case for classification, whether to weight
+        weighted_loss (bool, default=False): In case for classification, whether to weight
             the loss function according to the distribution of classes in the target
         use_stochastic_weight_averaging (bool, default=True): whether to use stochastic
             weight averaging. Stochastic weight averaging is a simple average of
@@ -221,7 +221,7 @@ class BaseTrainerComponent(autoPyTorchTrainingComponent):
         random_state:
         **lookahead_config:
     """
-    def __init__(self, weighted_loss: int = 0,
+    def __init__(self, weighted_loss: bool = False,
                  use_stochastic_weight_averaging: bool = True,
                  use_snapshot_ensemble: bool = True,
                  se_lastk: int = 3,
@@ -587,8 +587,8 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=[1],
-            default_value=1),
+            value_range=[True, False],
+            default_value=True),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",
             value_range=(5, 10),
@@ -636,16 +636,9 @@ def get_hyperparameter_search_space(
             parent_hyperparameter=parent_hyperparameter
         )

-        """
-        # TODO, decouple the weighted loss from the trainer
-        if dataset_properties is not None:
-            if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
-                add_hyperparameter(cs, weighted_loss, CategoricalHyperparameter)
-        """
-        # TODO, decouple the weighted loss from the trainer. Uncomment the code above and
-        # remove the code below. Also update the method signature, so the weighted loss
-        # is not a constant.
         if dataset_properties is not None:
             if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
-                add_hyperparameter(cs, weighted_loss, Constant)
+                add_hyperparameter(cs, weighted_loss, CategoricalHyperparameter)

         return cs
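The docstring change above pins down the semantics: with weighted_loss=True, the classification loss is weighted according to the class distribution of the target. A hedged PyTorch sketch of the common inverse-frequency scheme (illustrative only; autoPyTorch derives its weights in its own loss utilities, which may differ):

import numpy as np
import torch

y = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])     # 80/20 imbalanced target
counts = np.bincount(y)
weights = counts.sum() / (len(counts) * counts)   # rarer class gets a larger weight

criterion = torch.nn.CrossEntropyLoss(
    weight=torch.tensor(weights, dtype=torch.float32))

logits = torch.randn(10, 2)                       # fake model outputs
loss = criterion(logits, torch.from_numpy(y))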
16 changes: 4 additions & 12 deletions autoPyTorch/pipeline/components/training/trainer/cutout_utils.py
@@ -20,7 +20,7 @@

 class CutOut:
     def __init__(self, patch_ratio: float,
                  cutout_prob: float,
-                 weighted_loss: int = 0,
+                 weighted_loss: bool = False,
                  random_state: Optional[np.random.RandomState] = None,
                  use_stochastic_weight_averaging: bool = False,
                  use_snapshot_ensemble: bool = False,
@@ -63,9 +63,8 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=[1],
-            default_value=1
-        ),
+            value_range=[True, False],
+            default_value=True),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",
             value_range=(5, 10),
@@ -137,16 +136,9 @@ def get_hyperparameter_search_space(
             parent_hyperparameter=parent_hyperparameter
         )

-        """
-        # TODO, decouple the weighted loss from the trainer
-        if dataset_properties is not None:
-            if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
-                add_hyperparameter(cs, weighted_loss, CategoricalHyperparameter)
-        """
-        # TODO, decouple the weighted loss from the trainer. Uncomment the code above and
-        # remove the code below. Also update the method signature, so the weighted loss
-        # is not a constant.
         if dataset_properties is not None:
             if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
-                add_hyperparameter(cs, weighted_loss, Constant)
+                add_hyperparameter(cs, weighted_loss, CategoricalHyperparameter)

         return cs
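The CutOut signature above carries two knobs: cutout_prob, the chance of perturbing a given row, and patch_ratio, the fraction of features to knock out. A hedged numpy sketch of that idea for tabular data (the trainer's real implementation, including how targets are treated, may differ):

import numpy as np

def cutout_row(x: np.ndarray, patch_ratio: float, cutout_prob: float,
               rng: np.random.RandomState) -> np.ndarray:
    # With probability cutout_prob, zero out a random patch_ratio
    # fraction of the feature columns of a single row.
    if rng.rand() > cutout_prob:
        return x
    n_zero = max(1, int(patch_ratio * x.shape[0]))
    idx = rng.choice(x.shape[0], size=n_zero, replace=False)
    x = x.copy()
    x[idx] = 0.0
    return x

rng = np.random.RandomState(0)
cutout_row(np.arange(8.0), patch_ratio=0.25, cutout_prob=0.5, rng=rng)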
16 changes: 4 additions & 12 deletions autoPyTorch/pipeline/components/training/trainer/mixup_utils.py
@@ -19,7 +19,7 @@

 class MixUp:
     def __init__(self, alpha: float,
-                 weighted_loss: int = 0,
+                 weighted_loss: bool = False,
                  random_state: Optional[np.random.RandomState] = None,
                  use_stochastic_weight_averaging: bool = False,
                  use_snapshot_ensemble: bool = False,
@@ -61,9 +61,8 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=[1],
-            default_value=1
-        ),
+            value_range=[True, False],
+            default_value=True),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",
             value_range=(5, 10),
@@ -128,16 +127,9 @@ def get_hyperparameter_search_space(
             la_config_space,
             parent_hyperparameter=parent_hyperparameter
         )
-        """
-        # TODO, decouple the weighted loss from the trainer
-        if dataset_properties is not None:
-            if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
-                add_hyperparameter(cs, weighted_loss, CategoricalHyperparameter)
-        """
-        # TODO, decouple the weighted loss from the trainer. Uncomment the code above and
-        # remove the code below. Also update the method signature, so the weighted loss
-        # is not a constant.
         if dataset_properties is not None:
             if STRING_TO_TASK_TYPES[dataset_properties['task_type']] in CLASSIFICATION_TASKS:
-                add_hyperparameter(cs, weighted_loss, Constant)
+                add_hyperparameter(cs, weighted_loss, CategoricalHyperparameter)

         return cs
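MixUp's alpha parameter sets the Beta(alpha, alpha) distribution used to blend pairs of training examples (Zhang et al., 2018). A hedged numpy sketch of the per-batch idea (autoPyTorch's trainer mixes the loss over both target vectors accordingly; details may differ):

import numpy as np

def mixup_batch(x: np.ndarray, y: np.ndarray, alpha: float,
                rng: np.random.RandomState):
    # Blend each example with a randomly drawn partner from the same batch.
    lam = rng.beta(alpha, alpha) if alpha > 0 else 1.0
    perm = rng.permutation(x.shape[0])
    x_mix = lam * x + (1.0 - lam) * x[perm]
    # The loss is then lam * loss(y) + (1 - lam) * loss(y[perm]).
    return x_mix, y, y[perm], lam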
2 changes: 1 addition & 1 deletion test/test_pipeline/test_tabular_classification.py
@@ -380,7 +380,7 @@ def test_set_choices_updates(self, fit_dictionary_tabular):
     @pytest.mark.parametrize('lr_scheduler', ['CosineAnnealingWarmRestarts',
                                               'ReduceLROnPlateau'])
     def test_trainer_cocktails(self, fit_dictionary_tabular, mocker, lr_scheduler, trainer):  # noqa F811
-        fit_dictionary_tabular['epochs'] = 20
+        fit_dictionary_tabular['epochs'] = 45
         fit_dictionary_tabular['early_stopping'] = 20
         pipeline = TabularClassificationPipeline(
             dataset_properties=fit_dictionary_tabular['dataset_properties'],
