[ADD] Documentation update in base_trainer.py #468

Merged · 1 commit · Aug 12, 2022
235 changes: 181 additions & 54 deletions autoPyTorch/pipeline/components/training/trainer/base_trainer.py
@@ -45,25 +45,53 @@ def __init__(self,
An object for tracking when to stop the network training.
It handles epoch-based criteria as well as runtime-based criteria.

It also allows defining an 'epoch_or_time' budget type, which means that whichever of the two budgets is exhausted first is honored.

Args:
budget_type (str):
Type of budget to be used when fitting the pipeline.
Possible values are 'epochs', 'runtime', or 'epoch_or_time'
max_epochs (Optional[int], default=None):
Maximum number of epochs to train the pipeline for
max_runtime (Optional[int], default=None):
Maximum number of seconds to train the pipeline for
"""
self.start_time = time.time()
self.budget_type = budget_type
self.max_epochs = max_epochs
self.max_runtime = max_runtime

def is_max_epoch_reached(self, epoch: int) -> bool:
"""
For budget type 'epochs' or 'epoch_or_time', return True if the maximum number of epochs has been reached.

Args:
epoch (int):
the current epoch

Returns:
bool:
True if the current epoch exceeds the maximum number of epochs, False otherwise.
Also returns False if no epoch limit is set for this run.
"""
# A max_epochs of None means the run is not constrained by epochs
if self.max_epochs is None:
return False
if self.budget_type in ['epochs', 'epoch_or_time'] and epoch > self.max_epochs:
return True
return False

def is_max_time_reached(self) -> bool:
"""
For budget type 'runtime' or 'epoch_or_time', return True if the maximum runtime has been reached.

Returns:
bool:
True if the maximum runtime is reached, False otherwise.
Also returns False if no runtime limit is set for this run.
"""
# A max_runtime of None means the run is not constrained by time
if self.max_runtime is None:
return False
elapsed_time = time.time() - self.start_time
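To make the budget semantics above concrete, here is a minimal usage sketch (not part of this diff). It assumes the class documented here is the trainer's BudgetTracker and that it is importable from base_trainer with the constructor arguments listed above.

```python
import time

# Minimal sketch, assuming the class documented above is BudgetTracker and is
# importable from base_trainer with the documented constructor arguments.
from autoPyTorch.pipeline.components.training.trainer.base_trainer import BudgetTracker

budget_tracker = BudgetTracker(
    budget_type='epoch_or_time',  # whichever budget is exhausted first is honored
    max_epochs=50,
    max_runtime=600,              # seconds
)

epoch = 1
while not (budget_tracker.is_max_epoch_reached(epoch) or budget_tracker.is_max_time_reached()):
    # ... train one epoch here ...
    time.sleep(0.1)  # placeholder for real training work
    epoch += 1
```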
@@ -78,14 +106,22 @@ def __init__(
total_parameter_count: float,
trainable_parameter_count: float,
optimize_metric: Optional[str] = None,
) -> None:
"""
A useful object to track performance per epoch.

It allows tracking of train, validation and test information, not only for debugging but also for research purposes (e.g. understanding overfitting).

It does so by tracking a metric/loss at the end of each epoch.

Args:
total_parameter_count (float):
the total number of parameters of the model
trainable_parameter_count (float):
the number of trainable parameters, i.e. only those being optimized
optimize_metric (Optional[str], default=None):
name of the metric that is used to evaluate a pipeline.
"""
self.performance_tracker: Dict[str, Dict] = {
'start_time': {},
@@ -121,8 +157,30 @@ def add_performance(self,
test_loss: Optional[float] = None,
) -> None:
"""
Tracks performance information about the run, useful for plotting individual runs.

Args:
epoch (int):
the current epoch
start_time (float):
timestamp at the beginning of the current epoch
end_time (float):
timestamp when gathering the information after the current epoch
train_loss (float):
the training loss
train_metrics (Dict[str, float]):
training scores for each desired metric
val_metrics (Dict[str, float]):
validation scores for each desired metric
test_metrics (Dict[str, float]):
test scores for each desired metric
val_loss (Optional[float], default=None):
the validation loss
test_loss (Optional[float], default=None):
the test loss

Returns:
None
"""
self.performance_tracker['train_loss'][epoch] = train_loss
self.performance_tracker['val_loss'][epoch] = val_loss
@@ -134,6 +192,18 @@ def add_performance(self,
self.performance_tracker['test_metrics'][epoch] = test_metrics

def get_best_epoch(self, split_type: str = 'val') -> int:
"""
Get the epoch with the best metric.

Args:
split_type (str, default='val'):
Which split's metric to consider.
Possible values are 'train' or 'val'

Returns:
int:
the epoch with the best metric
"""
# If we compute for optimization, prefer the performance
# metric to the loss
if self.optimize_metric is not None:
@@ -159,6 +229,13 @@ def get_best_epoch(self, split_type: str = 'val') -> int:
)) + 1 # Epochs start at 1

def get_last_epoch(self) -> int:
"""
Get the last epoch.

Returns:
int:
the last epoch
"""
if 'train_loss' not in self.performance_tracker:
return 0
else:
@@ -170,7 +247,8 @@ def repr_last_epoch(self) -> str:
performance

Returns:
str:
A nice representation of the last epoch
"""
last_epoch = len(self.performance_tracker['train_loss'])
string = "\n"
@@ -202,30 +280,43 @@ def is_empty(self) -> bool:
Checks if the object is empty or not

Returns:
bool:
True if the object is empty, False otherwise
"""
# if train_loss is empty, we can be sure that RunSummary is empty.
return not bool(self.performance_tracker['train_loss'])
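The epoch-selection rule documented for get_best_epoch (prefer the optimization metric when one was given, otherwise fall back to the loss) can be pictured with a small self-contained sketch. The helper below is hypothetical and assumes higher metric values and lower losses are better; it only mirrors the described behaviour, not the class internals.

```python
from typing import Dict, Optional

def best_epoch(val_metrics: Dict[int, Dict[str, float]],
               val_loss: Dict[int, float],
               optimize_metric: Optional[str] = None) -> int:
    """Hypothetical helper mirroring the documented rule:
    prefer the tracked metric, otherwise fall back to the loss."""
    if optimize_metric is not None:
        # higher metric value is assumed to be better here
        return max(val_metrics, key=lambda epoch: val_metrics[epoch][optimize_metric])
    # lower loss is better
    return min(val_loss, key=val_loss.get)

# Example: epoch 3 has the best accuracy, epoch 2 the lowest loss
metrics = {1: {'accuracy': 0.70}, 2: {'accuracy': 0.78}, 3: {'accuracy': 0.81}}
losses = {1: 0.9, 2: 0.6, 3: 0.65}
assert best_epoch(metrics, losses, optimize_metric='accuracy') == 3
assert best_epoch(metrics, losses) == 2
```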


class BaseTrainerComponent(autoPyTorchTrainingComponent):
"""
Base class for training.

Args:
weighted_loss (int, default=0):
For classification, whether to weight the loss function according to the distribution of classes
in the target
use_stochastic_weight_averaging (bool, default=True):
whether to use stochastic weight averaging. Stochastic weight averaging is a simple average of
multiple points (model parameters) along the trajectory of SGD. SWA has been proposed in
[Averaging Weights Leads to Wider Optima and Better Generalization](https://arxiv.org/abs/1803.05407)
use_snapshot_ensemble (bool, default=True):
whether to use snapshot ensemble
se_lastk (int, default=3):
Number of snapshots of the network to maintain
use_lookahead_optimizer (bool, default=True):
whether to use lookahead optimizer
random_state (Optional[np.random.RandomState]):
Object that contains a seed and allows for reproducible results
swa_model (Optional[torch.nn.Module], default=None):
Averaged model used for Stochastic Weight Averaging
model_snapshots (Optional[List[torch.nn.Module]], default=None):
List of model snapshots in case snapshot ensemble is used
**lookahead_config (Any):
keyword arguments for the lookahead optimizer including:
la_steps (int):
number of lookahead steps
la_alpha (float):
linear interpolation factor. 1.0 recovers the inner optimizer.
"""
def __init__(self, weighted_loss: int = 0,
use_stochastic_weight_averaging: bool = True,
@@ -336,15 +427,21 @@ def prepare(

def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
"""
Optional placeholder for AutoPyTorch extensions.
A user can define what happens on every epoch start or every epoch end.

Args:
X (Dict[str, Any]):
Dictionary with fitted parameters. It is a message passing mechanism in which, during a transform,
a component adds relevant information so that further stages can be properly fitted
epoch (int):
the current epoch
"""
pass
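A hypothetical extension built on these hooks could look like the sketch below. The subclass name and the stopping rule are made up for illustration; a complete trainer would also have to implement the abstract data_preparation and criterion_preparation methods shown later in this file.

```python
from typing import Any, Dict

# Hypothetical sketch of a user extension using the documented hooks.
# Assumes BaseTrainerComponent is importable from base_trainer; the stopping
# rule below (a hard cap of 20 epochs) is purely illustrative.
from autoPyTorch.pipeline.components.training.trainer.base_trainer import BaseTrainerComponent

class CappedTrainer(BaseTrainerComponent):
    def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
        # e.g. reset per-epoch state or log the epoch number
        pass

    def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
        # Keep the base behaviour (scheduler/stopping handling) and
        # additionally stop once the illustrative cap is reached.
        return super().on_epoch_end(X, epoch) or epoch >= 20
```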

def _swa_update(self) -> None:
"""
Perform Stochastic Weight Averaging model update
"""
if self.swa_model is None:
raise ValueError("SWA model cannot be none when stochastic weight averaging is enabled")
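The actual weight update is collapsed in this diff. Conceptually, SWA keeps a running average of the model weights while a snapshot ensemble keeps the last few model copies; the sketch below shows the standard PyTorch recipe for both ideas and is not the trainer's exact implementation (the names model_snapshots and se_lastk are taken from the class docstring above).

```python
import copy

import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

# Conceptual sketch of the two techniques described in the class docstring;
# NOT the trainer's exact implementation, just the standard recipe.
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

swa_model = AveragedModel(model)   # running average of the weights (SWA)
model_snapshots = []               # snapshot ensemble
se_lastk = 3                       # keep only the last k snapshots

for epoch in range(10):
    # ... one training epoch on `model` ...
    swa_model.update_parameters(model)             # SWA: fold current weights into the average
    model_snapshots.append(copy.deepcopy(model))   # snapshot ensemble: store a copy
    model_snapshots = model_snapshots[-se_lastk:]  # retain only the last k snapshots
```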
@@ -354,6 +451,7 @@ def _swa_update(self) -> None:
def _se_update(self, epoch: int) -> None:
"""
Add latest model or swa_model to model snapshot ensemble

Args:
epoch (int):
current epoch
@@ -373,9 +471,16 @@ def _se_update(self, epoch: int) -> None:

def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
"""
Optional placeholder for AutoPyTorch extensions.
A user can define what happens on every epoch start or every epoch end.
If it returns True, the training is stopped.

Args:
X (Dict[str, Any]):
Dictionary with fitted parameters. It is a message passing mechanism in which, during a transform,
a component adds relevant information so that further stages can be properly fitted
epoch (int):
the current epoch

"""
if X['is_cyclic_scheduler']:
@@ -421,12 +526,18 @@ def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
Train the model for a single epoch.

Args:
train_loader (torch.utils.data.DataLoader):
generator of features/labels
epoch (int):
The current epoch used solely for tracking purposes
writer (Optional[SummaryWriter]):
Object to keep track of the training loss in an event file

Returns:
float:
training loss
Dict[str, float]:
scores for each desired metric
"""

loss_sum = 0.0
@@ -482,12 +593,16 @@ def train_step(self, data: torch.Tensor, targets: torch.Tensor) -> Tuple[float,
Performs one step of gradient descent, given a batch of features and labels.

Args:
data (torch.Tensor):
input features to the network
targets (torch.Tensor):
ground truth to calculate loss

Returns:
float:
the loss incurred in the prediction
torch.Tensor:
The predictions of the network
"""
# prepare
data = data.float().to(self.device)
@@ -513,12 +628,18 @@ def evaluate(self, test_loader: torch.utils.data.DataLoader, epoch: int,
Evaluate the model on both the metrics and the criterion (loss).

Args:
test_loader (torch.utils.data.DataLoader):
generator of features/labels
epoch (int):
the current epoch for tracking purposes
writer (Optional[SummaryWriter]):
Object to keep track of the test loss in an event file

Returns:
float:
test loss
Dict[str, float]:
scores for each desired metric
"""
self.model.eval()

@@ -576,14 +697,15 @@ def get_class_weights(self, criterion: Type[torch.nn.Module], labels: Union[np.n
def data_preparation(self, X: torch.Tensor, y: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
"""
Depending on the trainer choice, the data fed to the network might be pre-processed in a different way. That is,
in standard training we provide the data to the network as we receive it from the loader. Some regularization
techniques, like mixup, alter the data.

Args:
X (torch.Tensor):
The batch training features
y (torch.Tensor):
The batch training labels

Returns:
torch.Tensor: that processes data
@@ -595,16 +717,21 @@ def criterion_preparation(self, y_a: torch.Tensor, y_b: torch.Tensor = None, lam: float = 1.0
def criterion_preparation(self, y_a: torch.Tensor, y_b: torch.Tensor = None, lam: float = 1.0
) -> Callable: # type: ignore
"""
Depending on the trainer choice, the criterion is not directly applied to the traditional
y_pred/y_ground_truth pairs; rather, a slight transformation might be applied first.
For example, in the case of mixup training, we need to account for the mixup lambda.

Args:
y_a (torch.Tensor):
the batch label of the first training example used in trainer
y_b (torch.Tensor, default=None):
if applicable, the batch label of the second training example used in trainer
lam (float):
trainer coefficient

Returns:
Callable:
a lambda function that contains the new criterion calculation recipe
"""
raise NotImplementedError()
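For context on the kind of transformation meant by data_preparation and criterion_preparation, a generic mixup-style pair of helpers is sketched below. The function names and the exact dictionary keys are assumptions for illustration, not the project's MixUp trainer.

```python
from typing import Callable, Dict, Tuple

import numpy as np
import torch

def mixup_data_preparation(X: torch.Tensor, y: torch.Tensor, alpha: float = 0.2
                           ) -> Tuple[torch.Tensor, Dict[str, object]]:
    """Generic mixup sketch: blend each example with a randomly permuted partner."""
    lam = float(np.random.beta(alpha, alpha))
    index = torch.randperm(X.size(0))
    mixed_X = lam * X + (1 - lam) * X[index]
    # y_a/y_b/lam are later consumed by the criterion preparation step
    return mixed_X, {'y_a': y, 'y_b': y[index], 'lam': lam}

def mixup_criterion_preparation(y_a: torch.Tensor, y_b: torch.Tensor, lam: float) -> Callable:
    """Return a closure that weights the criterion by the mixup coefficient."""
    return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
```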
