[add] documentation update in base trainer (#468)
theodorju authored Aug 12, 2022
1 parent afddca5 commit 34c704d
Showing 1 changed file with 181 additions and 54 deletions.
235 changes: 181 additions & 54 deletions autoPyTorch/pipeline/components/training/trainer/base_trainer.py
@@ -45,25 +45,53 @@ def __init__(self,
An object for tracking when to stop the network training.
It handles epoch based criteria as well as training based criteria.
It also allows defining an 'epoch_or_time' budget type, which means whichever of the two is exhausted first is honored.
Args:
budget_type (str):
Type of budget to be used when fitting the pipeline.
Possible values are 'epochs', 'runtime', or 'epoch_or_time'
max_epochs (Optional[int], default=None):
Maximum number of epochs to train the pipeline for
max_runtime (Optional[int], default=None):
Maximum number of seconds to train the pipeline for
"""
self.start_time = time.time()
self.budget_type = budget_type
self.max_epochs = max_epochs
self.max_runtime = max_runtime

def is_max_epoch_reached(self, epoch: int) -> bool:
"""
For budget type 'epochs' or 'epoch_or_time', return True if the maximum number of epochs is reached.
Args:
epoch (int):
the current epoch
Returns:
bool:
True if the current epoch exceeds the maximum number of epochs, False otherwise.
Also returns False if the run has no epoch constraint.
"""
# A max_epochs of None means the run is not constrained by epochs
if self.max_epochs is None:
return False
if self.budget_type in ['epochs', 'epoch_or_time'] and epoch > self.max_epochs:
return True
return False
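For context, a minimal sketch of how these budget checks might drive a training loop, combining the epoch check above with the time check documented next. The class name BudgetTracker and its import path are assumed from this module:

from autoPyTorch.pipeline.components.training.trainer.base_trainer import BudgetTracker

# 'epoch_or_time' honors whichever of the two budgets is exhausted first
budget_tracker = BudgetTracker(budget_type='epoch_or_time', max_epochs=50, max_runtime=3600)

epoch = 1
while not budget_tracker.is_max_epoch_reached(epoch) and not budget_tracker.is_max_time_reached():
    # ... train and evaluate one epoch here ...
    epoch += 1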

def is_max_time_reached(self) -> bool:
"""
For budget type 'runtime' or 'epoch_or_time', return True if the maximum runtime is reached.
Returns:
bool:
True if the maximum runtime is reached, False otherwise.
Also returns False if the run has no runtime constraint.
"""
# A max_runtime of None means the run is not constrained by runtime
if self.max_runtime is None:
return False
elapsed_time = time.time() - self.start_time
@@ -78,14 +106,22 @@ def __init__(
total_parameter_count: float,
trainable_parameter_count: float,
optimize_metric: Optional[str] = None,
) -> None:
"""
A useful object to track performance per epoch.
It allows tracking train, validation and test information not only for debugging, but also for research purposes
(like understanding overfitting).
It does so by tracking a metric/loss at the end of each epoch.
Args:
total_parameter_count (float):
the total number of parameters of the model
trainable_parameter_count (float):
the number of trainable (optimized) parameters
optimize_metric (Optional[str], default=None):
name of the metric that is used to evaluate a pipeline.
"""
self.performance_tracker: Dict[str, Dict] = {
'start_time': {},
@@ -121,8 +157,30 @@ def add_performance(self,
test_loss: Optional[float] = None,
) -> None:
"""
Tracks performance information about the run, useful for plotting individual runs.
Args:
epoch (int):
the current epoch
start_time (float):
timestamp at the beginning of the current epoch
end_time (float):
timestamp when gathering the information after the current epoch
train_loss (float):
the training loss
train_metrics (Dict[str, float]):
training scores for each desired metric
val_metrics (Dict[str, float]):
validation scores for each desired metric
test_metrics (Dict[str, float]):
test scores for each desired metric
val_loss (Optional[float], default=None):
the validation loss
test_loss (Optional[float], default=None):
the test loss
Returns:
None
"""
self.performance_tracker['train_loss'][epoch] = train_loss
self.performance_tracker['val_loss'][epoch] = val_loss
@@ -134,6 +192,18 @@
self.performance_tracker['test_metrics'][epoch] = test_metrics
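As an illustration of the tracking flow above, a sketch of recording per-epoch results with add_performance; the RunSummary name comes from this module, and the losses and metric values are hypothetical:

import time

from autoPyTorch.pipeline.components.training.trainer.base_trainer import RunSummary

run_summary = RunSummary(total_parameter_count=1.2e6,
                         trainable_parameter_count=1.0e6,
                         optimize_metric='accuracy')

for epoch in range(1, 4):
    start = time.time()
    # ... one epoch of training and evaluation would happen here ...
    run_summary.add_performance(epoch=epoch,
                                start_time=start,
                                end_time=time.time(),
                                train_loss=1.0 / epoch,                      # hypothetical values
                                train_metrics={'accuracy': 0.50 + 0.10 * epoch},
                                val_metrics={'accuracy': 0.45 + 0.10 * epoch},
                                test_metrics={},
                                val_loss=1.1 / epoch)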

def get_best_epoch(self, split_type: str = 'val') -> int:
"""
Get the epoch with the best metric.
Args:
split_type (str, default='val'):
Which split's metric to consider.
Possible values are 'train' or 'val'
Returns:
int:
the epoch with the best metric
"""
# If we compute for optimization, prefer the performance
# metric to the loss
if self.optimize_metric is not None:
@@ -159,6 +229,13 @@ def get_best_epoch(self, split_type: str = 'val') -> int:
)) + 1 # Epochs start at 1

def get_last_epoch(self) -> int:
"""
Get the last epoch.
Returns:
int:
the last epoch
"""
if 'train_loss' not in self.performance_tracker:
return 0
else:
@@ -170,7 +247,8 @@ def repr_last_epoch(self) -> str:
performance
Returns:
str:
A nice representation of the last epoch
"""
last_epoch = len(self.performance_tracker['train_loss'])
string = "\n"
@@ -202,30 +280,43 @@ def is_empty(self) -> bool:
Checks if the object is empty or not
Returns:
bool:
True if the object is empty, False otherwise
"""
# if train_loss is empty, we can be sure that RunSummary is empty.
return not bool(self.performance_tracker['train_loss'])
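Continuing the sketch above (with run_summary populated for a few epochs), the inspection helpers documented here might be used as follows:

if not run_summary.is_empty():
    best_epoch = run_summary.get_best_epoch(split_type='val')   # epoch with the best validation metric
    last_epoch = run_summary.get_last_epoch()                    # most recently recorded epoch
    print(f"best epoch: {best_epoch}, last epoch: {last_epoch}")
    print(run_summary.repr_last_epoch())                         # human-readable view of the last epoch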


class BaseTrainerComponent(autoPyTorchTrainingComponent):
"""
Base class for training.
Args:
weighted_loss (int, default=0):
In the case of classification, whether to weight the loss function according to the distribution of
classes in the target
use_stochastic_weight_averaging (bool, default=True):
whether to use stochastic weight averaging. Stochastic weight averaging is a simple average of
multiple points (model parameters) along the trajectory of SGD. SWA has been proposed in
[Averaging Weights Leads to Wider Optima and Better Generalization](https://arxiv.org/abs/1803.05407)
use_snapshot_ensemble (bool, default=True):
whether to use snapshot ensemble
se_lastk (int, default=3):
Number of snapshots of the network to maintain
use_lookahead_optimizer (bool, default=True):
whether to use lookahead optimizer
random_state (Optional[np.random.RandomState]):
Object that contains a seed and allows for reproducible results
swa_model (Optional[torch.nn.Module], default=None):
Averaged model used for Stochastic Weight Averaging
model_snapshots (Optional[List[torch.nn.Module]], default=None):
List of model snapshots in case snapshot ensemble is used
**lookahead_config (Any):
keyword arguments for the lookahead optimizer including:
la_steps (int):
number of lookahead steps
la_alpha (float):
linear interpolation factor. 1.0 recovers the inner optimizer.
"""
def __init__(self, weighted_loss: int = 0,
use_stochastic_weight_averaging: bool = True,
@@ -336,15 +427,21 @@ def prepare(

def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
"""
Optional placeholder for AutoPyTorch extensions.
A user can define what happens on every epoch start or every epoch end.
Args:
X (Dict[str, Any]):
Dictionary with fitted parameters. It is a message passing mechanism, in which during a transform,
a component adds relevant information so that further stages can be properly fitted
epoch (int):
the current epoch
"""
pass
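As a sketch of this extension point, a hypothetical trainer that overrides both epoch hooks; a real trainer in this package would also implement data_preparation and criterion_preparation, sketched further below:

from typing import Any, Dict

from autoPyTorch.pipeline.components.training.trainer.base_trainer import BaseTrainerComponent

class LoggingTrainer(BaseTrainerComponent):
    """Hypothetical extension that only adds per-epoch logging."""

    def on_epoch_start(self, X: Dict[str, Any], epoch: int) -> None:
        # X is the fit dictionary used for message passing between pipeline stages
        print(f"starting epoch {epoch}")

    def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
        print(f"finished epoch {epoch}")
        return super().on_epoch_end(X, epoch)  # keep the base class's stopping logic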

def _swa_update(self) -> None:
"""
Perform Stochastic Weight Averaging model update
"""
if self.swa_model is None:
raise ValueError("SWA model cannot be none when stochastic weight averaging is enabled")
@@ -354,6 +451,7 @@ def _swa_update(self) -> None:
def _se_update(self, epoch: int) -> None:
"""
Add latest model or swa_model to model snapshot ensemble
Args:
epoch (int):
current epoch
@@ -373,9 +471,16 @@ def _se_update(self, epoch: int) -> None:

def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool:
"""
Optional placeholder for AutoPyTorch extensions.
A user can define what happens on every epoch start or every epoch end.
If this returns True, the training is stopped.
Args:
X (Dict[str, Any]):
Dictionary with fitted parameters. It is a message passing mechanism, in which during a transform,
a component adds relevant information so that further stages can be properly fitted
epoch (int):
the current epoch
"""
if X['is_cyclic_scheduler']:
@@ -421,12 +526,18 @@ def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
Train the model for a single epoch.
Args:
train_loader (torch.utils.data.DataLoader):
generator of features/labels
epoch (int):
The current epoch used solely for tracking purposes
writer (Optional[SummaryWriter]):
Object to keep track of the training loss in an event file
Returns:
float:
training loss
Dict[str, float]:
scores for each desired metric
"""

loss_sum = 0.0
@@ -482,12 +593,16 @@ def train_step(self, data: torch.Tensor, targets: torch.Tensor) -> Tuple[float,
Performs one step of gradient descent, given a batch of features/labels
Args:
data (torch.Tensor):
input features to the network
targets (torch.Tensor):
ground truth to calculate loss
Returns:
torch.Tensor:
The predictions of the network
float:
the loss incurred in the prediction
"""
# prepare
data = data.float().to(self.device)
@@ -513,12 +628,18 @@ def evaluate(self, test_loader: torch.utils.data.DataLoader, epoch: int,
Evaluate the model on both the desired metrics and the loss criterion
Args:
test_loader (torch.utils.data.DataLoader):
generator of features/labels
epoch (int):
the current epoch for tracking purposes
writer (Optional[SummaryWriter]):
Object to keep track of the test loss in an event file
Returns:
float:
test loss
Dict[str, float]:
scores for each desired metric
"""
self.model.eval()

@@ -576,14 +697,15 @@ def get_class_weights(self, criterion: Type[torch.nn.Module], labels: Union[np.n
def data_preparation(self, X: torch.Tensor, y: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
"""
Depending on the trainer choice, data fed to the network might be pre-processed in a different way. That is,
in standard training we provide the data to the network as we receive it from the loader. Some regularization
techniques, like mixup, alter the data.
Args:
X (torch.Tensor):
The batch training features
y (torch.Tensor):
The batch training labels
Returns:
torch.Tensor: the processed data
@@ -595,16 +717,21 @@ def data_preparation(self, X: torch.Tensor, y: torch.Tensor,
def criterion_preparation(self, y_a: torch.Tensor, y_b: torch.Tensor = None, lam: float = 1.0
) -> Callable: # type: ignore
"""
Depending on the trainer choice, the criterion is not directly applied to the traditional
y_pred/y_ground_truth pairs, but rather it might have a slight transformation.
For example, in the case of mixup training, we need to account for the mixup lambda
Args:
y_a (torch.Tensor):
the batch label of the first training example used in trainer
y_b (torch.Tensor, default=None):
if applicable, the batch label of the second training example used in trainer
lam (float):
trainer coefficient
Returns:
Callable:
a lambda function that contains the new criterion calculation recipe
"""
raise NotImplementedError()
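Since mixup is the motivating example for both data_preparation and criterion_preparation, a schematic sketch of what a mixup-style pair of hooks might return; this mirrors the standard mixup recipe and is not necessarily this repository's exact implementation:

from typing import Any, Callable, Dict, Tuple

import numpy as np
import torch

def mixup_data_preparation(X: torch.Tensor, y: torch.Tensor,
                           alpha: float = 0.2) -> Tuple[torch.Tensor, Dict[str, Any]]:
    # draw the mixup coefficient and a random pairing of the batch
    lam = float(np.random.beta(alpha, alpha))
    index = torch.randperm(X.size(0))
    mixed_X = lam * X + (1 - lam) * X[index]
    # the criterion later needs both label sets and the coefficient
    return mixed_X, {'y_a': y, 'y_b': y[index], 'lam': lam}

def mixup_criterion_preparation(y_a: torch.Tensor, y_b: torch.Tensor, lam: float = 1.0) -> Callable:
    # mix the two losses with the same coefficient that was used for the data
    return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)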

