Shake Shake updates #287

Merged · 30 commits · Sep 30, 2021
Commits:

a5374ed · To test locally (ravinkohli, Sep 2, 2021)
06ad658 · fix bug in trainer choice fit (ravinkohli, Sep 6, 2021)
1942279 · fix ensemble bug (ravinkohli, Sep 8, 2021)
2dc8850 · Correct bug in cleanup (ravinkohli, Sep 8, 2021)
a80eb9e · To test locally (ravinkohli, Sep 2, 2021)
87168db · Merge branch 'shake-even' of github.com:automl/Auto-PyTorch into shak… (ravinkohli, Sep 10, 2021)
06d80d4 · Cleanup for removing time debug statements (ravinkohli, Sep 16, 2021)
d8b553a · ablation for adversarial (ravinkohli, Sep 20, 2021)
062de69 · Merge branch 'cocktail_fixes_time_debug' of github.com:automl/Auto-Py… (ravinkohli, Sep 20, 2021)
34712b3 · shuffle false in dataloader (ravinkohli, Sep 21, 2021)
49f40dc · drop last false in dataloader (ravinkohli, Sep 21, 2021)
f4ea158 · fix bug for validation set, and cutout and cutmix (ravinkohli, Sep 23, 2021)
fca1399 · To test locally (ravinkohli, Sep 2, 2021)
5d03fb2 · Merge branch 'shake-even' of github.com:automl/Auto-PyTorch into shak… (ravinkohli, Sep 24, 2021)
209a4e8 · shuffle = False (ravinkohli, Sep 24, 2021)
d18fcca · To test locally (ravinkohli, Sep 2, 2021)
b432882 · Merge branch 'shake-even' of github.com:automl/Auto-PyTorch into shak… (ravinkohli, Sep 24, 2021)
b38bfb3 · updates to search space (ravinkohli, Sep 26, 2021)
8c2f2ac · updates to search space (ravinkohli, Sep 26, 2021)
f0676b1 · update branch with search space (ravinkohli, Sep 26, 2021)
82d950c · undo search space update (ravinkohli, Sep 27, 2021)
30ba55e · fix bug in shake shake flag (ravinkohli, Sep 27, 2021)
e406f5b · limit to shake-even (ravinkohli, Sep 27, 2021)
863cc06 · restrict to even even (ravinkohli, Sep 27, 2021)
2921781 · Add even even and others for shake-drop also (ravinkohli, Sep 29, 2021)
e9359da · fix bug in passing alpha beta method (ravinkohli, Sep 29, 2021)
7f25e6f · restrict to only even even (ravinkohli, Sep 29, 2021)
dd5cb5b · fix silly bug: (ravinkohli, Sep 29, 2021)
0bb8436 · remove imputer and ordinal encoder for categorical transformer in fea… (ravinkohli, Sep 30, 2021)
89e595e · Address comments from shuhei (ravinkohli, Sep 30, 2021)
37 changes: 24 additions & 13 deletions autoPyTorch/api/base_task.py
@@ -397,6 +397,7 @@ def _clean_logger(self) -> None:
self.logging_server.join(timeout=5)
self.logging_server.terminate()
del self.stop_logging_server
self._logger = None

def _create_dask_client(self) -> None:
"""
@@ -491,6 +492,23 @@ def _load_models(self) -> bool:

return True

def _cleanup(self) -> None:
"""
Closes the different servers created during api search.
Returns:
None
"""
if hasattr(self, '_logger') and self._logger is not None:
self._logger.info("Closing the dask infrastructure")
self._close_dask_client()
self._logger.info("Finished closing the dask infrastructure")

# Clean up the logger
self._logger.info("Starting to clean up the logger")
self._clean_logger()
else:
self._close_dask_client()

def _load_best_individual_model(self) -> SingleBest:
"""
In case of failure during ensemble building,
@@ -923,6 +941,8 @@ def _search(
self._stopwatch.stop_task(traditional_task_name)

# ============> Starting ensemble
self.precision = precision
self.opt_metric = optimize_metric
elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
time_left_for_ensembles = max(0, total_walltime_limit - elapsed_time)
proc_ensemble = None
@@ -1024,18 +1044,12 @@
pd.DataFrame(self.ensemble_performance_history).to_json(
os.path.join(self._backend.internals_directory, 'ensemble_history.json'))

self._logger.info("Closing the dask infrastructure")
self._close_dask_client()
self._logger.info("Finished closing the dask infrastructure")

if load_models:
self._logger.info("Loading models...")
self._load_models()
self._logger.info("Finished loading models...")

# Clean up the logger
self._logger.info("Starting to clean up the logger")
self._clean_logger()
self._cleanup()

return self

@@ -1230,7 +1244,7 @@ def fit_pipeline(self,
dataset_requirements = get_dataset_requirements(
info=self._get_required_dataset_properties(dataset))
dataset_properties = dataset.get_dataset_properties(dataset_requirements)
self._backend.save_datamanager(dataset)
self._backend.replace_datamanager(dataset)

if self._logger is None:
self._logger = self._get_logger(dataset.dataset_name)
@@ -1506,7 +1520,7 @@ def predict(

predictions = self.ensemble_.predict(all_predictions)

self._clean_logger()
self._cleanup()

return predictions

@@ -1543,10 +1557,7 @@ def __getstate__(self) -> Dict[str, Any]:
return self.__dict__

def __del__(self) -> None:
# Clean up the logger
self._clean_logger()

self._close_dask_client()
self._cleanup()

# When a multiprocessing work is done, the
# objects are deleted. We don't want to delete run areas
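The base_task.py hunks above consolidate teardown: the Dask shutdown and logger cleanup that _search, predict, and __del__ each performed separately now route through a single _cleanup method, and _clean_logger drops the handle (self._logger = None) so repeated calls are safe. A minimal sketch of that pattern; the class and attribute names are hypothetical stand-ins, not Auto-PyTorch's real API:

import logging
from typing import Optional


class TaskWithResources:
    # Illustrative stand-in; only the structure mirrors the merged hunks.

    def __init__(self) -> None:
        self._logger: Optional[logging.Logger] = logging.getLogger("task")
        self._dask_client_open = True  # stand-in for a dask.distributed.Client

    def _close_dask_client(self) -> None:
        self._dask_client_open = False

    def _clean_logger(self) -> None:
        self._logger = None  # as in the first hunk: later cleanup calls become no-ops

    def _cleanup(self) -> None:
        # Single exit path shared by search, predict and __del__.
        if getattr(self, "_logger", None) is not None:
            self._logger.info("Closing the dask infrastructure")
            self._close_dask_client()
            self._logger.info("Finished closing the dask infrastructure")
            self._clean_logger()
        else:
            self._close_dask_client()

    def __del__(self) -> None:
        self._cleanup()  # safe even if _cleanup has already run

Guarding with getattr mirrors the hasattr check in the merged _cleanup, which matters because __del__ can run late in interpreter shutdown, when attributes may already be gone.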
30 changes: 15 additions & 15 deletions autoPyTorch/data/tabular_feature_validator.py
@@ -41,26 +41,26 @@ def get_tabular_preprocessors():
preprocessors['numerical'] = list()
preprocessors['categorical'] = list()

# preprocessors['categorical'].append(SimpleImputer(strategy='constant',
# # Train data is numpy
# # as of this point, where
# # Ordinal Encoding is using
# # for categorical. Only
# # Numbers are allowed
# # fill_value='!missing!',
# fill_value=-1,
# copy=False))

# preprocessors['categorical'].append(OrdinalEncoder(
# handle_unknown='use_encoded_value',
# unknown_value=-1))

preprocessors['categorical'].append(OneHotEncoder(
categories='auto',
sparse=False,
handle_unknown='ignore'))
preprocessors['categorical'].append(SimpleImputer(strategy='constant',
# Train data is numpy
# as of this point, where
# Ordinal Encoding is using
# for categorical. Only
# Numbers are allowed
# fill_value='!missing!',
fill_value=-1,
copy=False))

preprocessors['categorical'].append(OrdinalEncoder(
handle_unknown='use_encoded_value',
unknown_value=-1))

preprocessors['numerical'].append(SimpleImputer(strategy='median',
copy=False))
copy=False))
preprocessors['numerical'].append(StandardScaler(with_mean=True, with_std=True, copy=False))

return preprocessors
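In the categorical branch above, the imputer-plus-ordinal-encoder pair (which the commit list shows being swapped back and forth against one-hot encoding) replaces missing entries with -1 and maps categories to integer codes, sending values unseen at fit time to -1 as well. A toy illustration of that pair in isolation, assuming scikit-learn >= 0.24 for handle_unknown='use_encoded_value'; the data is made up:

import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OrdinalEncoder

categorical = make_pipeline(
    SimpleImputer(strategy="constant", fill_value=-1, copy=False),
    OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1),
)

# One categorical column of numeric codes; train data is numpy at this point.
X_train = np.array([[0.0], [1.0], [np.nan], [2.0]])
categorical.fit(X_train)
print(categorical.transform(np.array([[1.0], [np.nan], [7.0]])))
# [[2.], [0.], [-1.]]: NaN imputes to -1 (encoded as 0), unseen 7.0 maps to -1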
5 changes: 3 additions & 2 deletions autoPyTorch/pipeline/base_pipeline.py
@@ -451,12 +451,13 @@ def _check_search_space_updates(self, include: Optional[Dict[str, Any]],
continue
raise ValueError("Unknown hyperparameter for component {}. "
"Expected update hyperparameter "
"to be in {} got {}".format(node.__class__.__name__,
"to be in {} got {}. choice is {}".format(node.__class__.__name__,
component.
get_hyperparameter_search_space(
dataset_properties=self.dataset_properties).
get_hyperparameter_names(),
split_hyperparameter[1]))
split_hyperparameter[1],
component.__name__))
else:
if update.hyperparameter not in node.get_hyperparameter_search_space(
dataset_properties=self.dataset_properties):
@@ -23,7 +23,6 @@ def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = N
self.add_fit_requirements([
FitRequirement('numerical_columns', (List,), user_defined=True, dataset_property=True),
FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True)])
self.fit_time = None

def get_column_transformer(self) -> ColumnTransformer:
"""
@@ -48,7 +47,6 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "TabularColumnTransformer":
Returns:
"TabularColumnTransformer": an instance of self
"""
start_time = time.time()

self.check_requirements(X, y)
numerical_pipeline = 'passthrough'
@@ -74,7 +72,6 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "TabularColumnTransformer":
X_train = X['backend'].load_datamanager().train_tensors[0]

self.preprocessor.fit(X_train)
self.fit_time = time.time() - start_time

return self

@@ -139,6 +139,14 @@ def get_hyperparameter_search_space(
value_range=(True, False),
default_value=True,
),
shake_alpha_beta_method: HyperparameterSearchSpace = HyperparameterSearchSpace(
hyperparameter="shake_alpha_beta_method",
value_range=('shake-shake',
'shake-even',
'even-even',
'M3'),
default_value='shake-shake',
),
use_shake_drop: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="use_shake_drop",
value_range=(True, False),
default_value=True,
@@ -180,9 +188,8 @@ def get_hyperparameter_search_space(

if skip_connection_flag:

shake_drop_prob_flag = False
if 'shake-drop' in multi_branch_choice.value_range:
shake_drop_prob_flag = True
shake_shake_flag = 'shake-shake' in multi_branch_choice.value_range
shake_drop_prob_flag = 'shake-drop' in multi_branch_choice.value_range

mb_choice = get_hyperparameter(multi_branch_choice, CategoricalHyperparameter)
cs.add_hyperparameter(mb_choice)
@@ -192,6 +199,10 @@
shake_drop_prob = get_hyperparameter(max_shake_drop_probability, UniformFloatHyperparameter)
cs.add_hyperparameter(shake_drop_prob)
cs.add_condition(CS.EqualsCondition(shake_drop_prob, mb_choice, "shake-drop"))
if shake_shake_flag or shake_drop_prob_flag:
method = get_hyperparameter(shake_alpha_beta_method, CategoricalHyperparameter)
cs.add_hyperparameter(method)
cs.add_condition(CS.InCondition(method, mb_choice, ["shake-shake", "shake-drop"]))

# It is the upper bound of the nr of groups,
# since the configuration will actually be sampled.
@@ -327,11 +338,14 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
if self.config["multi_branch_choice"] == 'shake-shake':
x1 = self.layers(x)
x2 = self.shake_shake_layers(x)
alpha, beta = shake_get_alpha_beta(self.training, x.is_cuda)
alpha, beta = shake_get_alpha_beta(is_training=self.training,
is_cuda=x.is_cuda,
method=self.config['shake_alpha_beta_method'])
x = shake_shake(x1, x2, alpha, beta)
elif self.config["multi_branch_choice"] == 'shake-drop':
x = self.layers(x)
alpha, beta = shake_get_alpha_beta(self.training, x.is_cuda)
alpha, beta = shake_get_alpha_beta(self.training, x.is_cuda,
method=self.config['shake_alpha_beta_method'])
bl = shake_drop_get_bl(
self.block_index,
1 - self.config["max_shake_drop_probability"],
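Both backbone hunks register shake_alpha_beta_method as a conditional hyperparameter that only becomes active when the multi-branch choice actually uses shake coefficients. A reduced, self-contained ConfigSpace sketch of that structure; the value sets are copied from the hunk, while 'none' stands in for whatever non-shake options the real space offers:

import ConfigSpace as CS
from ConfigSpace.hyperparameters import CategoricalHyperparameter

cs = CS.ConfigurationSpace()
mb_choice = CategoricalHyperparameter(
    "multi_branch_choice", ["none", "shake-shake", "shake-drop"],
    default_value="shake-drop")
method = CategoricalHyperparameter(
    "shake_alpha_beta_method", ["shake-shake", "shake-even", "even-even", "M3"],
    default_value="shake-shake")
cs.add_hyperparameters([mb_choice, method])
# Sampled configurations only contain the method when a shake branch is chosen.
cs.add_condition(CS.InCondition(method, mb_choice, ["shake-shake", "shake-drop"]))
print(cs.sample_configuration())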
@@ -145,6 +145,14 @@ def get_hyperparameter_search_space( # type: ignore[override]
'stairs'),
default_value='funnel',
),
shake_alpha_beta_method: HyperparameterSearchSpace = HyperparameterSearchSpace(
hyperparameter="shake_alpha_beta_method",
value_range=('shake-shake',
'shake-even',
'even-even',
'M3'),
default_value='shake-shake',
),
max_shake_drop_probability: HyperparameterSearchSpace = HyperparameterSearchSpace(
hyperparameter="max_shake_drop_probability",
value_range=(0, 1),
@@ -188,9 +196,8 @@ def get_hyperparameter_search_space( # type: ignore[override]

if skip_connection_flag:

shake_drop_prob_flag = False
if 'shake-drop' in multi_branch_choice.value_range:
shake_drop_prob_flag = True
shake_shake_flag = 'shake-shake' in multi_branch_choice.value_range
shake_drop_prob_flag = 'shake-drop' in multi_branch_choice.value_range

mb_choice = get_hyperparameter(multi_branch_choice, CategoricalHyperparameter)
cs.add_hyperparameter(mb_choice)
@@ -200,5 +207,9 @@
shake_drop_prob = get_hyperparameter(max_shake_drop_probability, UniformFloatHyperparameter)
cs.add_hyperparameter(shake_drop_prob)
cs.add_condition(CS.EqualsCondition(shake_drop_prob, mb_choice, "shake-drop"))
if shake_shake_flag or shake_drop_prob_flag:
method = get_hyperparameter(shake_alpha_beta_method, CategoricalHyperparameter)
cs.add_hyperparameter(method)
cs.add_condition(CS.InCondition(method, mb_choice, ["shake-shake", "shake-drop"]))

return cs
28 changes: 24 additions & 4 deletions autoPyTorch/pipeline/components/setup/network_backbone/utils.py
@@ -92,15 +92,35 @@ def backward(ctx: typing.Any,
shake_drop = ShakeDropFunction.apply


def shake_get_alpha_beta(is_training: bool, is_cuda: bool
) -> typing.Tuple[torch.tensor, torch.tensor]:
def shake_get_alpha_beta(
is_training: bool,
is_cuda: bool,
method: str
) -> typing.Tuple[torch.tensor, torch.tensor]:
"""
The methods used in this function have been introduced in 'ShakeShake Regularisation'
https://arxiv.org/abs/1705.07485. The names have been taken from the paper as well.
"""
if not is_training:
result = (torch.FloatTensor([0.5]), torch.FloatTensor([0.5]))
return result if not is_cuda else (result[0].cuda(), result[1].cuda())

# TODO implement other update methods
alpha = torch.rand(1)
beta = torch.rand(1)
if method == 'even-even':
alpha = torch.FloatTensor([0.5])
else:
alpha = torch.rand(1)

if method == 'shake-shake':
beta = torch.rand(1)
elif method in ['shake-even', 'even-even']:
beta = torch.FloatTensor([0.5])
elif method == 'M3':
beta = torch.FloatTensor(
[torch.rand(1)*(0.5 - alpha)*alpha if alpha < 0.5 else torch.rand(1)*(alpha - 0.5)*alpha]
)
else:
raise ValueError("Unknown method for ShakeShakeRegularisation in NetworkBackbone")

if is_cuda:
alpha = alpha.cuda()
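For reference, here are the sampling rules from the hunk above as one runnable, self-contained function. This is a sketch that mirrors the merged code minus the CUDA transfer; the method names follow 'Shake-Shake regularization' (https://arxiv.org/abs/1705.07485):

import torch
from typing import Tuple


def sample_alpha_beta(method: str, training: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    if not training:
        # At evaluation time both coefficients collapse to their expectation.
        return torch.FloatTensor([0.5]), torch.FloatTensor([0.5])

    alpha = torch.FloatTensor([0.5]) if method == "even-even" else torch.rand(1)

    if method == "shake-shake":
        beta = torch.rand(1)             # alpha and beta drawn independently
    elif method in ("shake-even", "even-even"):
        beta = torch.FloatTensor([0.5])  # "even" backward coefficient
    elif method == "M3":
        # Mirrors the hunk: beta depends on alpha and shrinks as alpha nears 0.5.
        r = torch.rand(1)
        beta = r * (0.5 - alpha) * alpha if alpha.item() < 0.5 else r * (alpha - 0.5) * alpha
    else:
        raise ValueError(f"Unknown alpha/beta method: {method}")
    return alpha, beta


for m in ("shake-shake", "shake-even", "even-even", "M3"):
    a, b = sample_alpha_beta(m)
    print(m, round(a.item(), 3), round(b.item(), 3))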
@@ -95,9 +95,9 @@ def get_hyperparameter_search_space(
default_value=True,
),
weight_decay: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="weight_decay",
value_range=(1E-7, 0.1),
value_range=(1E-5, 0.1),
default_value=1E-4,
log=True),
log=False),
) -> ConfigurationSpace:
cs = ConfigurationSpace()

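The weight_decay change above alters the prior as well as the bounds: log-uniform sampling over (1e-7, 0.1) puts half the mass below the geometric midpoint of about 1e-4, while uniform sampling over (1e-5, 0.1) centres around 0.05. A quick, illustrative check with plain ConfigSpace (not library code):

import numpy as np
from ConfigSpace import ConfigurationSpace
from ConfigSpace.hyperparameters import UniformFloatHyperparameter


def median_weight_decay(lower: float, upper: float, log: bool) -> float:
    cs = ConfigurationSpace(seed=1)
    cs.add_hyperparameter(UniformFloatHyperparameter(
        "weight_decay", lower, upper, default_value=1e-4, log=log))
    samples = [cfg["weight_decay"] for cfg in cs.sample_configuration(2000)]
    return float(np.median(samples))


print(median_weight_decay(1e-7, 0.1, log=True))   # ~1e-4 (geometric midpoint)
print(median_weight_decay(1e-5, 0.1, log=False))  # ~0.05 (arithmetic midpoint)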
@@ -115,7 +115,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> torch.utils.data.DataLoader:
shuffle=True,
num_workers=X.get('num_workers', 0),
pin_memory=X.get('pin_memory', True),
drop_last=X.get('drop_last', True),
drop_last=X.get('drop_last', False),
collate_fn=custom_collate_fn,
)

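Flipping the drop_last default to False matters most on small tabular splits, where drop_last=True silently discards any trailing batch smaller than batch_size on every epoch. A toy demonstration with standalone PyTorch, not the Auto-PyTorch loader:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float())
for drop_last in (True, False):
    seen = sum(batch[0].numel()
               for batch in DataLoader(dataset, batch_size=4, drop_last=drop_last))
    print(drop_last, seen)  # True -> 8 samples per epoch, False -> all 10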
@@ -189,12 +189,17 @@ def get_hyperparameter_search_space(
default_value=3),
epsilon: HyperparameterSearchSpace = HyperparameterSearchSpace(
hyperparameter="epsilon",
value_range=(0.05, 0.2),
default_value=0.2),
value_range=(0.001, 0.15),
default_value=0.007,
log=True),
) -> ConfigurationSpace:
cs = ConfigurationSpace()

epsilon = HyperparameterSearchSpace(hyperparameter="epsilon",
value_range=(0.007, 0.007),
default_value=0.007)
add_hyperparameter(cs, epsilon, UniformFloatHyperparameter)

add_hyperparameter(cs, use_stochastic_weight_averaging, CategoricalHyperparameter)
snapshot_ensemble_flag = False
if any(use_snapshot_ensemble.value_range):
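Besides narrowing the epsilon search range in the signature, the body now overrides it with the degenerate range (0.007, 0.007), pinning epsilon for the adversarial-training ablation (commit d8b553a). Assuming Auto-PyTorch's add_hyperparameter helper collapses a degenerate range to a constant, the plain-ConfigSpace equivalent would be:

from ConfigSpace import ConfigurationSpace
from ConfigSpace.hyperparameters import Constant

cs = ConfigurationSpace(seed=0)
cs.add_hyperparameter(Constant("epsilon", 0.007))
print(cs.sample_configuration()["epsilon"])  # always 0.007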
@@ -36,7 +36,7 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
return X, {'y_a': y, 'y_b': y[index], 'lam': 1}

size = X.shape[1]
indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int32(size * lam)),
indices = torch.tensor(self.random_state.choice(range(size), max(1, np.int32(size * lam)),
replace=False))

X[:, indices] = X[index, :][:, indices]
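The one-character change above fixes an off-by-one: with range(1, size), feature column 0 could never be selected for mixing; the same fix appears in the next hunk. A quick standalone check:

import numpy as np

rng = np.random.RandomState(0)
size = 5
old_cols, new_cols = set(), set()
for _ in range(1000):
    old_cols.update(int(i) for i in rng.choice(range(1, size), 2, replace=False))
    new_cols.update(int(i) for i in rng.choice(range(size), 2, replace=False))
print(sorted(old_cols))  # [1, 2, 3, 4]: column 0 is unreachable
print(sorted(new_cols))  # [0, 1, 2, 3, 4]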
@@ -37,7 +37,7 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}

size = X.shape[1]
indices = self.random_state.choice(range(1, size), max(1, np.int32(size * self.patch_ratio)),
indices = self.random_state.choice(range(size), max(1, np.int32(size * self.patch_ratio)),
replace=False)

"""if not isinstance(self.numerical_columns, typing.Iterable):
@@ -371,16 +371,12 @@ def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
outputs_data = list()
targets_data = list()

batch_load_start_time = time.time()
for step, (data, targets) in enumerate(train_loader):
self.data_loading_times.append(time.time() - batch_load_start_time)
batch_train_start = time.time()
if self.budget_tracker.is_max_time_reached():
break

loss, outputs = self.train_step(data, targets)

self.batch_fit_times.append(time.time() - batch_train_start)
# save for metric evaluation
outputs_data.append(outputs.detach().cpu())
targets_data.append(targets.detach().cpu())
Expand All @@ -395,7 +391,6 @@ def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
loss,
epoch * len(train_loader) + step,
)
batch_load_start_time = time.time()

if self.scheduler:
if 'ReduceLROnPlateau' in self.scheduler.__class__.__name__: