From 18a671842acdaaafa1e02b52724de9409bfb9117 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Tue, 28 Feb 2023 11:51:33 +0100 Subject: [PATCH 01/11] use torch native amp --- flair/trainers/trainer.py | 88 +++++++++++++++++++++------------------ 1 file changed, 48 insertions(+), 40 deletions(-) diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index c8ba7f09d..4bd533c67 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -9,6 +9,7 @@ from typing import List, Optional, Tuple, Type, Union import torch +from torch.optim.lr_scheduler import OneCycleLR # type: ignore from torch.optim.sgd import SGD from torch.utils.data.dataset import ConcatDataset @@ -276,6 +277,8 @@ def fine_tune( create_file_logs=create_file_logs, create_loss_file=create_loss_file, write_weights=write_weights, + # amp + use_amp=use_amp, # plugins plugins=plugins, **kwargs, @@ -322,43 +325,43 @@ def train_custom( ) -> dict: """Trains any class that implements the flair.nn.Model interface. - Parameters - ---------- - base_path: Main path to which all output during training is logged and models are saved - learning_rate (float): The learning rate of the optimizer - decoder_learning_rate (Optional[float]): Optional, if set, the decoder is trained with a separate learning rate - mini_batch_size (int): Size of mini-batches during training - eval_batch_size (int): Size of mini-batches during evaluation - mini_batch_chunk_size (int): If mini-batches are larger than this number, they get broken down into chunks of - this size for processing purposes - max_epochs (int): Maximum number of epochs to train. Terminates training if this number is surpassed. - optimizer: The optimizer to use (typically SGD or Adam) - train_with_dev (bool): If True, the data from dev split is added to the training data - train_with_test (bool): If True, the data from test split is added to the training data - main_evaluation_metric: The metric to optimize (often micro-average or macro-average F1-score, or accuracy) - monitor_test (bool): If True, test data is evaluated at end of each epoch - monitor_train_sample: Set this to evaluate on a sample of the train data at the end of each epoch. - If you set an int, it will sample this many sentences to evaluate on. If you set a float, it will sample - a percentage of data points from train. - use_final_model_for_eval (bool): If True, the final model is used for the final evaluation. If False, the - model from the best epoch as determined by main_evaluation_metric is used for the final evaluation. - gold_label_dictionary_for_eval: Set to force evaluation to use a particular label dictionary - exclude_labels: Optionally define a list of labels to exclude from the evaluation - sampler: You can pass a data sampler here for special sampling of data. - shuffle: If True, data is shuffled during training - shuffle_first_epoch: If True, data is shuffled during the first epoch of training - embeddings_storage_mode: One of 'none' (all embeddings are deleted and freshly recomputed), - 'cpu' (embeddings stored on CPU) or 'gpu' (embeddings stored on GPU) - epoch: The starting epoch (normally 0 but could be higher if you continue training model) - save_final_model: If True, the final model is saved at the end of training. - save_optimizer_state (bool): If True, the optimizer state is saved alongside the model - save_model_each_k_epochs: Each k epochs, a model state will be written out. If set to '5', a model will - be saved each 5 epochs. Default is 0 which means no model saving. 
- create_file_logs (bool): If True, logging output is written to a file - create_loss_file (bool): If True, a loss file logging output is created - write_weights (bool): If True, write weights to weights.txt on each batch logging event. - plugins: Any additional plugins you want to pass to the trainer - **kwargs: Additional arguments, for instance for the optimizer + Args: + base_path: Main path to which all output during training is logged and models are saved + learning_rate (float): The learning rate of the optimizer + decoder_learning_rate (Optional[float]): Optional, if set, the decoder is trained with a separate learning rate + mini_batch_size (int): Size of mini-batches during training + eval_batch_size (int): Size of mini-batches during evaluation + mini_batch_chunk_size (int): If mini-batches are larger than this number, they get broken down into chunks of + this size for processing purposes + max_epochs (int): Maximum number of epochs to train. Terminates training if this number is surpassed. + optimizer: The optimizer to use (typically SGD or Adam) + train_with_dev (bool): If True, the data from dev split is added to the training data + train_with_test (bool): If True, the data from test split is added to the training data + main_evaluation_metric: The metric to optimize (often micro-average or macro-average F1-score, or accuracy) + monitor_test (bool): If True, test data is evaluated at end of each epoch + monitor_train_sample: Set this to evaluate on a sample of the train data at the end of each epoch. + If you set an int, it will sample this many sentences to evaluate on. If you set a float, it will sample + a percentage of data points from train. + use_final_model_for_eval (bool): If True, the final model is used for the final evaluation. If False, the + model from the best epoch as determined by main_evaluation_metric is used for the final evaluation. + gold_label_dictionary_for_eval: Set to force evaluation to use a particular label dictionary + exclude_labels: Optionally define a list of labels to exclude from the evaluation + sampler: You can pass a data sampler here for special sampling of data. + shuffle: If True, data is shuffled during training + shuffle_first_epoch: If True, data is shuffled during the first epoch of training + embeddings_storage_mode: One of 'none' (all embeddings are deleted and freshly recomputed), + 'cpu' (embeddings stored on CPU) or 'gpu' (embeddings stored on GPU) + epoch: The starting epoch (normally 0 but could be higher if you continue training model) + save_final_model: If True, the final model is saved at the end of training. + save_optimizer_state (bool): If True, the optimizer state is saved alongside the model + save_model_each_k_epochs: Each k epochs, a model state will be written out. If set to '5', a model will + be saved each 5 epochs. Default is 0 which means no model saving. + create_file_logs (bool): If True, logging output is written to a file + create_loss_file (bool): If True, a loss file logging output is created + use_amp (bool): If True, uses the torch automatic mixed precision + write_weights (bool): If True, write weights to weights.txt on each batch logging event. 
+ plugins: Any additional plugins you want to pass to the trainer + **kwargs: Additional arguments, for instance for the optimizer Returns: ------- @@ -471,6 +474,8 @@ def train_custom( # -- AnnealingPlugin -> initialize schedulers (requires instantiated optimizer) self.dispatch("after_setup", **parameters) + scaler = torch.cuda.amp.GradScaler(enabled=use_amp) + final_eval_info = ( "model after last epoch (final-model.pt)" if use_final_model_for_eval @@ -567,12 +572,13 @@ def train_custom( # forward and backward for batch for batch_step in batch_steps: # forward pass - loss, datapoint_count = self.model.forward_loss(batch_step) + with torch.autocast(device_type=flair.device.type, dtype=torch.float16, enabled=use_amp): + loss, datapoint_count = self.model.forward_loss(batch_step) batch_train_samples += datapoint_count batch_train_loss += loss.item() - self._backward(loss) + self._backward(scaler.scale(loss)) # identify dynamic embeddings (always deleted) on first sentence if dynamic_embeddings is None: @@ -584,8 +590,10 @@ def train_custom( self.dispatch("before_training_optimizer_step", **batch_kw) # do the optimizer step + scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0) - self.optimizer.step() + scaler.step(self.optimizer) + scaler.update() if batch_train_samples > 0: train_loss = batch_train_loss / batch_train_samples From b71e4edc6fd3d4f913e331721537a8709ac1e4e0 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Sat, 4 Mar 2023 03:30:15 +0100 Subject: [PATCH 02/11] fix training type to default device type --- flair/embeddings/document.py | 4 ++-- flair/embeddings/transformer.py | 2 +- flair/models/sequence_tagger_model.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flair/embeddings/document.py b/flair/embeddings/document.py index 2110e13bb..0fdfbcd70 100644 --- a/flair/embeddings/document.py +++ b/flair/embeddings/document.py @@ -333,7 +333,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): pre_allocated_zero_tensor = torch.zeros( self.embeddings.embedding_length * longest_token_sequence_in_batch, - dtype=torch.float, + dtype=self.rnn.all_weights[0][0].dtype, device=flair.device, ) @@ -691,7 +691,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): pre_allocated_zero_tensor = torch.zeros( self.embeddings.embedding_length * padding_length, - dtype=torch.float, + dtype=self.convs[0].weight.dtype, device=flair.device, ) diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py index c494c1cf8..bc36693e0 100644 --- a/flair/embeddings/transformer.py +++ b/flair/embeddings/transformer.py @@ -51,7 +51,7 @@ def pad_sequence_embeddings(all_hidden_states: List[torch.Tensor]) -> torch.Tens longest_token_sequence_in_batch = hidden_states.shape[0] pre_allocated_zero_tensor = torch.zeros( embedding_length * longest_token_sequence_in_batch, - dtype=torch.float, + dtype=all_hidden_states[0].dtype, device=flair.device, ) all_embs = [] diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index 2353814b0..32d0ec11b 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -340,7 +340,7 @@ def _make_padded_tensor_for_batch(self, sentences: List[Sentence]) -> Tuple[torc longest_token_sequence_in_batch: int = max(lengths) pre_allocated_zero_tensor = torch.zeros( self.embeddings.embedding_length * longest_token_sequence_in_batch, - dtype=torch.float, + dtype=self.linear.weight.dtype, device=flair.device, ) 
all_embs = [] From 5d77f557e9cf455866a4b39969de234afe606922 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Sat, 4 Mar 2023 04:00:08 +0100 Subject: [PATCH 03/11] add support for `.half()` --- flair/embeddings/token.py | 2 +- flair/embeddings/transformer.py | 11 ++++++++--- flair/models/sequence_tagger_utils/viterbi.py | 4 ++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py index 771c3a639..b553b08ea 100644 --- a/flair/embeddings/token.py +++ b/flair/embeddings/token.py @@ -525,7 +525,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): outputs = outputs.transpose(0, 1) chars_embeds_temp = torch.zeros( (outputs.size(0), outputs.size(2)), - dtype=torch.float, + dtype=outputs.dtype, device=flair.device, ) for i, index in enumerate(output_lengths): diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py index bc36693e0..38cca673f 100644 --- a/flair/embeddings/transformer.py +++ b/flair/embeddings/transformer.py @@ -181,7 +181,9 @@ def fill_mean_token_embeddings( @torch.jit.script_if_tracing def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor): - result = torch.zeros(sentence_hidden_states.shape[0], sentence_hidden_states.shape[2]) + result = torch.zeros( + sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype + ) for i in torch.arange(sentence_hidden_states.shape[0]): result[i] = sentence_hidden_states[i, : sentence_lengths[i]].mean(dim=0) @@ -189,7 +191,9 @@ def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths @torch.jit.script_if_tracing def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor): - result = torch.zeros(sentence_hidden_states.shape[0], sentence_hidden_states.shape[2]) + result = torch.zeros( + sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype + ) for i in torch.arange(sentence_hidden_states.shape[0]): result[i], _ = sentence_hidden_states[i, : sentence_lengths[i]].max(dim=0) @@ -1328,7 +1332,8 @@ def forward( assert word_ids is not None assert token_lengths is not None all_token_embeddings = torch.zeros( # type: ignore[call-overload] - word_ids.shape[0], token_lengths.max(), self.embedding_length_internal, device=flair.device + word_ids.shape[0], token_lengths.max(), self.embedding_length_internal, device=flair.device, + dtype=sentence_hidden_states.dtype, ) true_tensor = torch.ones_like(word_ids[:, :1], dtype=torch.bool) if self.subtoken_pooling == "first": diff --git a/flair/models/sequence_tagger_utils/viterbi.py b/flair/models/sequence_tagger_utils/viterbi.py index f1d3fa256..1cae3c008 100644 --- a/flair/models/sequence_tagger_utils/viterbi.py +++ b/flair/models/sequence_tagger_utils/viterbi.py @@ -53,7 +53,7 @@ def forward(self, features_tuple: tuple, targets: torch.Tensor) -> torch.Tensor: ] gold_score = scores_at_targets.sum() + transitions_to_stop.sum() - scores_upto_t = torch.zeros(batch_size, self.tagset_size, device=flair.device) + scores_upto_t = torch.zeros(batch_size, self.tagset_size, device=flair.device, dtype=features.dtype) for t in range(max(lengths)): batch_size_t = sum( @@ -151,7 +151,7 @@ def decode( seq_len = features.size(1) # Create a tensor to hold accumulated sequence scores at each current tag - scores_upto_t = torch.zeros(batch_size, seq_len + 1, self.tagset_size).to(flair.device) + scores_upto_t = torch.zeros(batch_size, seq_len + 1, 
self.tagset_size, dtype=features.dtype).to(flair.device) # Create a tensor to hold back-pointers # i.e., indices of the previous_tag that corresponds to maximum accumulated score at current tag # Let pads be the tag index, since that was the last tag in the decoded sequence From 92e9e385a1bddae16d36d6b4286f2b71c98ec94c Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 6 Mar 2023 16:16:30 +0100 Subject: [PATCH 04/11] add AMP to the tutorial --- resources/docs/TUTORIAL_TRAINING_MORE.md | 25 ++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/resources/docs/TUTORIAL_TRAINING_MORE.md b/resources/docs/TUTORIAL_TRAINING_MORE.md index 9d16aa629..1889e9506 100644 --- a/resources/docs/TUTORIAL_TRAINING_MORE.md +++ b/resources/docs/TUTORIAL_TRAINING_MORE.md @@ -97,27 +97,36 @@ mini-batch size. Remember that this is the opposite of `mini_batch_size` so this ### Setting the Storage Mode of Embeddings -Another main parameter you need to set is the `embeddings_storage_mode` in the `train()` method of the `ModelTrainer`. It +Another main parameter you need to set is the `embeddings_storage_mode` in the `train()` method of the `ModelTrainer`. can have one of three values: -1. **'none'**: If you set `embeddings_storage_mode='none'`, embeddings do not get stored in memory. Instead they are +1. **'none'**: If you set `embeddings_storage_mode='none'`, embeddings do not get stored in memory. Instead, they are generated on-the-fly in each training mini-batch (during *training*). The main advantage is that this keeps your memory requirements low. Always set this if fine-tuning transformers. 2. **'cpu'**: If you set `embeddings_storage_mode='cpu'`, embeddings will get stored in regular memory. -* during *training*: this in many cases speeds things up significantly since embeddings only need to be computed in the - first epoch, after which they are just retrieved from memory. A disadvantage is that this increases memory - requirements. Depending on the size of your dataset and your memory setup, this option may not be possible. -* during *inference*: this slows down your inference when used with a GPU as embeddings need to be moved from GPU memory - to regular memory. The only reason to use this option during inference would be to not only use the predictions but - also the embeddings after prediction. + * during *training*: this in many cases speeds things up significantly since static embeddings only need to be computed in the + first epoch, after which they are just retrieved from memory. A disadvantage is that this increases memory + requirements. Depending on the size of your dataset and your memory setup, this option may not be possible. + * during *inference*: this slows down your inference when used with a GPU as embeddings need to be moved from GPU memory + to regular memory. The only reason to use this option during inference would be to not only use the predictions but + also the embeddings after prediction. 3. **'gpu'**: If you set `embeddings_storage_mode='gpu'`, embeddings will get stored in CUDA memory. This will often be the fastest one since this eliminates the need to shuffle tensors from CPU to CUDA over and over again. Of course, CUDA memory is often limited so large datasets will not fit into CUDA memory. However, if the dataset fits into CUDA memory, this option is the fastest one. 
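For readers following PATCH 01: the training loop now follows the standard torch-native AMP recipe, that is, a forward pass under `torch.autocast`, a backward pass on the scaled loss, `unscale_` before gradient clipping, then `scaler.step` and `scaler.update`. The sketch below condenses that recipe into a self-contained loop; the model, data, and hyperparameters are placeholders rather than Flair's classes, and it assumes PyTorch 1.10+ for the device-agnostic `torch.autocast` context.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"

# Placeholder model and optimizer; in Flair these are self.model and self.optimizer.
model = torch.nn.Linear(128, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# GradScaler becomes a no-op when disabled, so the same loop also runs on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for _ in range(10):
    inputs = torch.randn(32, 128, device=device)
    targets = torch.randint(0, 10, (32,), device=device)

    optimizer.zero_grad()

    # Forward pass (and loss) under autocast so eligible ops run in float16.
    with torch.autocast(device_type=device.type, enabled=use_amp):
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)

    # Scale the loss before backward to avoid float16 gradient underflow.
    scaler.scale(loss).backward()

    # Unscale before clipping so the threshold applies to the true gradients.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)

    # scaler.step() skips the update if the gradients contain inf/NaN values.
    scaler.step(optimizer)
    scaler.update()
```

Unscaling before `clip_grad_norm_` matters because clipping should see the true gradient magnitudes, which is why PATCH 01 inserts `scaler.unscale_(self.optimizer)` ahead of the existing clipping call.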
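PATCHES 02 and 03 repeatedly swap a hard-coded `dtype=torch.float` for the dtype of a nearby parameter or hidden state, so that pre-allocated padding tensors match the activations produced under autocast or after `.half()`. A minimal illustration of that pattern, using a stand-in linear layer rather than Flair's actual modules:

```python
import torch

# Stand-in for a model parameter whose dtype the padding tensor should follow
# (in the patches: self.linear.weight, self.rnn.all_weights, self.convs[0].weight, ...).
linear = torch.nn.Linear(16, 4).half()
hidden_states = torch.randn(3, 5, 16, dtype=torch.float16)

# Previously dtype=torch.float; deriving it from the weight keeps half precision intact.
padding = torch.zeros(
    hidden_states.shape[0],
    7,
    hidden_states.shape[2],
    dtype=linear.weight.dtype,
    device=hidden_states.device,
)
padded = torch.cat([hidden_states, padding], dim=1)
assert padded.dtype == torch.float16
```

Without the change, the float32 padding would either force an upcast or raise a dtype mismatch when combined with float16 embeddings, defeating the memory savings of mixed precision.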
+### Training with Automated Mixed Precision (AMP) + +A good way to speed up the training time and use less memory is [Automated Mixed Precision training](https://pytorch.org/docs/stable/amp.html). +Here calculations will be done with a smaller data type (for example by using *float16* instead of *float32*). That way +less memory is required and the training time is reduced by a good amount. +AMP can be activated by setting the `use_amp` parameter in the `train()` method of the `ModelTrainer` to `True` + +You can choose the data type for the AMP by using `torch.set_autocast_gpu_dtype(...)` or `torch.set_autocast_cpu_dtype(...)` respectively. + ### Reducing the memory food-print when using Transformers Especially when you use multilingual transformer embeddings such as [xlm-roberta-large](https://huggingface.co/xlm-roberta-large) or [xlm-v-base](facebook/xlm-v-base), From 7e3da045b8abe11f0e3b22bf6d1b42c35c240ed2 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 20 Mar 2023 19:46:45 +0100 Subject: [PATCH 05/11] use amp in language model trainer --- flair/trainers/language_model_trainer.py | 54 ++++++++++-------------- 1 file changed, 22 insertions(+), 32 deletions(-) diff --git a/flair/trainers/language_model_trainer.py b/flair/trainers/language_model_trainer.py index f5b6632e2..97915f07f 100644 --- a/flair/trainers/language_model_trainer.py +++ b/flair/trainers/language_model_trainer.py @@ -4,7 +4,7 @@ import random import time from pathlib import Path -from typing import Iterable, Optional, Type, Union +from typing import Any, Dict, Iterable, Optional, Type, Union import torch from torch import cuda @@ -13,16 +13,10 @@ from torch.optim.sgd import SGD from torch.utils.data import DataLoader, Dataset -from flair.optim import SGDW, ReduceLRWDOnPlateau - -try: - from apex import amp -except ImportError: - amp = None - import flair from flair.data import Dictionary from flair.models import LanguageModel +from flair.optim import SGDW, ReduceLRWDOnPlateau from flair.training_utils import add_file_handler log = logging.getLogger("flair") @@ -166,7 +160,8 @@ def __init__( epoch: int = 0, split: int = 0, loss: float = 10000, - optimizer_state: Optional[dict] = None, + optimizer_state: Optional[Dict[str, Any]] = None, + scaler_state: Optional[Dict[str, Any]] = None, ) -> None: self.model: LanguageModel = model self.optimizer: Type[Optimizer] = optimizer @@ -179,6 +174,7 @@ def __init__( self.split = split self.loss = loss self.optimizer_state = optimizer_state + self.scaler_state = scaler_state def train( self, @@ -194,16 +190,8 @@ def train( grow_to_sequence_length: int = 0, num_workers: int = 2, use_amp: bool = False, - amp_opt_level: str = "O1", **kwargs, ): - if use_amp and amp is None: - raise RuntimeError( - "Failed to import apex. Please install apex from " - "https://www.github.com/nvidia/apex " - "to enable mixed-precision training." 
- ) - # cast string to Path base_path = Path(base_path) @@ -229,6 +217,12 @@ def train( best_val_loss = self.loss kwargs["lr"] = learning_rate optimizer = self.optimizer(self.model.parameters(), **kwargs) + + scaler = torch.cuda.amp.GradScaler(enabled=use_amp and flair.device.type != "cpu") + + if self.scaler_state: + scaler.load_state_dict(self.scaler_state) + if self.optimizer_state is not None: optimizer.load_state_dict(self.optimizer_state) @@ -239,9 +233,6 @@ def train( else: scheduler = ReduceLROnPlateau(optimizer, verbose=True, factor=anneal_factor, patience=patience) - if use_amp: - self.model, optimizer = amp.initialize(self.model, optimizer, opt_level=amp_opt_level) - training_generator = DataLoader(self.corpus.train, shuffle=False, num_workers=num_workers) for epoch in range(self.epoch, max_epochs): @@ -296,24 +287,22 @@ def train( self.model.zero_grad() optimizer.zero_grad() + with torch.autocast(device_type=flair.device.type, enabled=use_amp): + # do the forward pass in the model + output, rnn_output, hidden = self.model.forward(data, hidden) - # do the forward pass in the model - output, rnn_output, hidden = self.model.forward(data, hidden) - - # try to predict the targets - loss = self.loss_function(output.view(-1, ntokens), targets) + # try to predict the targets + loss = self.loss_function(output.view(-1, ntokens), targets) # Backward - if use_amp: - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() + scaler.scale(loss).backward() + + scaler.unscale_(optimizer) # `clip_grad_norm` helps prevent the exploding gradient # problem in RNNs / LSTMs. torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip) - - optimizer.step() + scaler.step(optimizer) + scaler.update() total_loss += loss.data @@ -467,4 +456,5 @@ def load_checkpoint( split=checkpoint["split"], loss=checkpoint["loss"], optimizer_state=checkpoint["optimizer_state_dict"], + scaler_state=checkpoint.get("scaler_state_dict"), ) From 2cab4030a7cfbb439208bafa886db27e4fc9360c Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 24 Apr 2023 15:27:33 +0200 Subject: [PATCH 06/11] add missing use_amp parameter --- flair/trainers/trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 4bd533c67..a6b3e1648 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -231,6 +231,8 @@ def fine_tune( create_file_logs: bool = True, create_loss_file: bool = True, write_weights: bool = False, + # amp + use_amp: bool = False, # plugins plugins: Optional[List[TrainerPlugin]] = None, attach_default_scheduler: bool = True, @@ -319,6 +321,8 @@ def train_custom( create_file_logs: bool = True, create_loss_file: bool = True, write_weights: bool = False, + # amp + use_amp: bool = False, # plugins plugins: List[TrainerPlugin] = [], **kwargs, From 295f1dbe28cf16aca5052d61d9dbe004e0dfdd37 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 24 Apr 2023 15:28:44 +0200 Subject: [PATCH 07/11] disable grad scaler on cpu to prevent warning --- flair/trainers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index a6b3e1648..8fb90c5f9 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -478,7 +478,7 @@ def train_custom( # -- AnnealingPlugin -> initialize schedulers (requires instantiated optimizer) self.dispatch("after_setup", **parameters) - scaler = torch.cuda.amp.GradScaler(enabled=use_amp) + scaler = 
torch.cuda.amp.GradScaler(enabled=use_amp and flair.device.type != "cpu") final_eval_info = ( "model after last epoch (final-model.pt)" From f43a53d870e9e12eac1afa8828f5d99173ee46c8 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 24 Apr 2023 15:31:48 +0200 Subject: [PATCH 08/11] fix unused import --- flair/trainers/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 8fb90c5f9..4c4f946f0 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -9,7 +9,6 @@ from typing import List, Optional, Tuple, Type, Union import torch -from torch.optim.lr_scheduler import OneCycleLR # type: ignore from torch.optim.sgd import SGD from torch.utils.data.dataset import ConcatDataset From eb288ff5cac2ad6e993da0a06b06db31e185f004 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Mon, 24 Apr 2023 16:11:49 +0200 Subject: [PATCH 09/11] remove fixing of autocast variable --- flair/trainers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 4c4f946f0..ce3a4c585 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -575,7 +575,7 @@ def train_custom( # forward and backward for batch for batch_step in batch_steps: # forward pass - with torch.autocast(device_type=flair.device.type, dtype=torch.float16, enabled=use_amp): + with torch.autocast(device_type=flair.device.type, enabled=use_amp): loss, datapoint_count = self.model.forward_loss(batch_step) batch_train_samples += datapoint_count From 91861bd14447388d326664b7b358e232d3b58b38 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 28 Apr 2023 21:50:49 +0200 Subject: [PATCH 10/11] reformatting after rebase --- flair/embeddings/transformer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py index 38cca673f..9d2a8b5ab 100644 --- a/flair/embeddings/transformer.py +++ b/flair/embeddings/transformer.py @@ -1332,7 +1332,10 @@ def forward( assert word_ids is not None assert token_lengths is not None all_token_embeddings = torch.zeros( # type: ignore[call-overload] - word_ids.shape[0], token_lengths.max(), self.embedding_length_internal, device=flair.device, + word_ids.shape[0], + token_lengths.max(), + self.embedding_length_internal, + device=flair.device, dtype=sentence_hidden_states.dtype, ) true_tensor = torch.ones_like(word_ids[:, :1], dtype=torch.bool) From abcc7591ab2a8a4dc3827ac45c55b43d3ccbcd2c Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Sun, 30 Apr 2023 14:46:41 +0200 Subject: [PATCH 11/11] fix schduler steps without optimization --- flair/trainers/plugins/functional/linear_scheduler.py | 5 ++++- flair/trainers/trainer.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/flair/trainers/plugins/functional/linear_scheduler.py b/flair/trainers/plugins/functional/linear_scheduler.py index e9a19ae08..08aca32c2 100644 --- a/flair/trainers/plugins/functional/linear_scheduler.py +++ b/flair/trainers/plugins/functional/linear_scheduler.py @@ -61,12 +61,15 @@ def before_training_epoch(self, **kw): self.previous_learning_rate = self.current_learning_rate @TrainerPlugin.hook - def after_training_batch(self, **kw): + def after_training_batch(self, optimizer_was_run: bool, **kw): """Do the scheduler step if one-cycle or linear decay. :param kw: :return: """ + # skip if no optimization has happened. 
+ if not optimizer_was_run: + return self.scheduler.step() self.store_learning_rate() diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index ce3a4c585..60b930438 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -595,8 +595,11 @@ def train_custom( # do the optimizer step scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0) + scale_before = scaler.get_scale() scaler.step(self.optimizer) scaler.update() + scale_after = scaler.get_scale() + batch_kw["optimizer_was_run"] = scale_before <= scale_after if batch_train_samples > 0: train_loss = batch_train_loss / batch_train_samples
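Taken together, the series lets users opt into mixed precision with a single flag, as the tutorial section added in PATCH 04 describes. A usage sketch along those lines is shown below; the corpus, tagger configuration, and output path are illustrative assumptions rather than anything prescribed by these patches, and selecting bfloat16 via `torch.set_autocast_gpu_dtype` additionally requires hardware that supports it.

```python
import torch

from flair.datasets import UD_ENGLISH
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

# Illustrative corpus and tagger; any flair.nn.Model is trained the same way.
corpus = UD_ENGLISH().downsample(0.1)
label_dict = corpus.make_label_dictionary(label_type="upos")

tagger = SequenceTagger(
    hidden_size=256,
    embeddings=TransformerWordEmbeddings("distilbert-base-uncased", fine_tune=True),
    tag_dictionary=label_dict,
    tag_type="upos",
)

# Optional: pick the autocast dtype, as mentioned at the end of the new tutorial section.
torch.set_autocast_gpu_dtype(torch.bfloat16)

trainer = ModelTrainer(tagger, corpus)

# use_amp=True is the switch these patches add to fine_tune()/train_custom().
trainer.fine_tune("resources/taggers/upos-amp", use_amp=True, mini_batch_size=16)
```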
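PATCH 05 also constructs the language model trainer's scaler with `enabled=use_amp and flair.device.type != "cpu"` and restores its state when resuming, and PATCH 07 applies the same CPU guard in `ModelTrainer`. The sketch below shows that bookkeeping in isolation; the `scaler_state_dict` key mirrors the one read in `LanguageModelTrainer.load_checkpoint()`, while the saving side is an assumption since it is not part of this excerpt.

```python
import torch

import flair

use_amp = True
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Disable the scaler on CPU so enabling AMP there does not trigger GradScaler warnings.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp and flair.device.type != "cpu")

# Save the scaler next to the optimizer so a resumed run keeps its loss scale.
checkpoint = {
    "optimizer_state_dict": optimizer.state_dict(),
    "scaler_state_dict": scaler.state_dict(),
}

# On resume, tolerate older checkpoints that predate the new key.
scaler_state = checkpoint.get("scaler_state_dict")
if scaler_state:
    scaler.load_state_dict(scaler_state)
```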
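Finally, PATCH 11 prevents the linear scheduler from advancing when the GradScaler skipped `optimizer.step()` because of inf/NaN gradients; a skipped step lowers the scale, so comparing `get_scale()` before and after `update()` detects it. A standalone sketch of that detection, using a generic `LinearLR` schedule in place of Flair's scheduler plugin:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=100)
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

for _ in range(3):
    inputs = torch.randn(32, 8, device=device)
    targets = torch.randint(0, 2, (32,), device=device)

    optimizer.zero_grad()
    with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    scaler.scale(loss).backward()

    # GradScaler skips optimizer.step() on inf/NaN gradients and then lowers its
    # scale, so a scale that did not shrink after update() means the step ran.
    scale_before = scaler.get_scale()
    scaler.step(optimizer)
    scaler.update()
    optimizer_was_run = scale_before <= scaler.get_scale()

    # Only advance the LR schedule when the optimizer actually stepped, so the
    # schedule does not run ahead of the optimization.
    if optimizer_was_run:
        scheduler.step()
```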