
Commit

Merge pull request #3128 from flairNLP/use_torch_amp
use torch native amp
alanakbik authored Aug 8, 2023
2 parents 911d915 + abcc759 commit c96660c
Showing 9 changed files with 115 additions and 91 deletions.
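The change set has two parts. The trainers drop the apex dependency and use PyTorch's native mixed precision instead: forward passes are wrapped in torch.autocast and backward/optimizer steps go through torch.cuda.amp.GradScaler. The embedding, tagger and Viterbi code additionally stops hard-coding dtype=torch.float for pre-allocated scratch tensors, letting them follow the dtype of the surrounding weights or hidden states. For orientation, here is a minimal sketch of the torch-native AMP pattern the trainers adopt; the model, data and hyperparameters are placeholders, not taken from this diff.

import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device_type == "cuda"  # GradScaler and autocast are effectively no-ops when disabled

model = torch.nn.Linear(16, 4).to(device_type)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for _ in range(3):
    x = torch.randn(8, 16, device=device_type)
    y = torch.randn(8, 4, device=device_type)
    optimizer.zero_grad()

    # forward pass under autocast: eligible ops run in reduced precision
    with torch.autocast(device_type=device_type, enabled=use_amp):
        loss = torch.nn.functional.mse_loss(model(x), y)

    scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
    scaler.unscale_(optimizer)      # unscale before clipping so the threshold is meaningful
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
    scaler.step(optimizer)          # skips the update if inf/NaN gradients were produced
    scaler.update()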
4 changes: 2 additions & 2 deletions flair/embeddings/document.py
@@ -333,7 +333,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):

         pre_allocated_zero_tensor = torch.zeros(
             self.embeddings.embedding_length * longest_token_sequence_in_batch,
-            dtype=torch.float,
+            dtype=self.rnn.all_weights[0][0].dtype,
             device=flair.device,
         )

@@ -691,7 +691,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):

         pre_allocated_zero_tensor = torch.zeros(
             self.embeddings.embedding_length * padding_length,
-            dtype=torch.float,
+            dtype=self.convs[0].weight.dtype,
             device=flair.device,
         )
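All of the embedding-side edits apply the same idea: a padding or scratch buffer allocated as plain torch.float either raises a dtype-mismatch error or forces an up-cast once the model runs in float16 or bfloat16, so the buffers now inherit their dtype from the weights or hidden states they are combined with. A small illustration of the pattern, with a made-up module and shapes:

import torch

linear = torch.nn.Linear(8, 4)

for precision in (torch.float32, torch.float16, torch.bfloat16):
    module = linear.to(precision)
    # Old behaviour: a hard-coded float32 buffer, which raises a dtype-mismatch
    # error as soon as it is multiplied with half-precision weights.
    fixed = torch.zeros(10, 8, dtype=torch.float)
    # New behaviour: the buffer simply follows whatever dtype the weights have.
    matching = torch.zeros(10, 8, dtype=module.weight.dtype)
    print(precision, fixed.dtype, matching.dtype)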
2 changes: 1 addition & 1 deletion flair/embeddings/token.py
@@ -525,7 +525,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
         outputs = outputs.transpose(0, 1)
         chars_embeds_temp = torch.zeros(
             (outputs.size(0), outputs.size(2)),
-            dtype=torch.float,
+            dtype=outputs.dtype,
             device=flair.device,
         )
         for i, index in enumerate(output_lengths):
16 changes: 12 additions & 4 deletions flair/embeddings/transformer.py
@@ -51,7 +51,7 @@ def pad_sequence_embeddings(all_hidden_states: List[torch.Tensor]) -> torch.Tens
             longest_token_sequence_in_batch = hidden_states.shape[0]
     pre_allocated_zero_tensor = torch.zeros(
         embedding_length * longest_token_sequence_in_batch,
-        dtype=torch.float,
+        dtype=all_hidden_states[0].dtype,
         device=flair.device,
     )
     all_embs = []
@@ -181,15 +181,19 @@ def fill_mean_token_embeddings(

 @torch.jit.script_if_tracing
 def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
-    result = torch.zeros(sentence_hidden_states.shape[0], sentence_hidden_states.shape[2])
+    result = torch.zeros(
+        sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
+    )

     for i in torch.arange(sentence_hidden_states.shape[0]):
         result[i] = sentence_hidden_states[i, : sentence_lengths[i]].mean(dim=0)


 @torch.jit.script_if_tracing
 def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
-    result = torch.zeros(sentence_hidden_states.shape[0], sentence_hidden_states.shape[2])
+    result = torch.zeros(
+        sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
+    )

     for i in torch.arange(sentence_hidden_states.shape[0]):
         result[i], _ = sentence_hidden_states[i, : sentence_lengths[i]].max(dim=0)
@@ -1328,7 +1332,11 @@ def forward(
         assert word_ids is not None
         assert token_lengths is not None
         all_token_embeddings = torch.zeros(  # type: ignore[call-overload]
-            word_ids.shape[0], token_lengths.max(), self.embedding_length_internal, device=flair.device
+            word_ids.shape[0],
+            token_lengths.max(),
+            self.embedding_length_internal,
+            device=flair.device,
+            dtype=sentence_hidden_states.dtype,
         )
         true_tensor = torch.ones_like(word_ids[:, :1], dtype=torch.bool)
         if self.subtoken_pooling == "first":
2 changes: 1 addition & 1 deletion flair/models/sequence_tagger_model.py
@@ -340,7 +340,7 @@ def _make_padded_tensor_for_batch(self, sentences: List[Sentence]) -> Tuple[torc
         longest_token_sequence_in_batch: int = max(lengths)
         pre_allocated_zero_tensor = torch.zeros(
             self.embeddings.embedding_length * longest_token_sequence_in_batch,
-            dtype=torch.float,
+            dtype=self.linear.weight.dtype,
             device=flair.device,
         )
         all_embs = []
4 changes: 2 additions & 2 deletions flair/models/sequence_tagger_utils/viterbi.py
@@ -53,7 +53,7 @@ def forward(self, features_tuple: tuple, targets: torch.Tensor) -> torch.Tensor:
         ]
         gold_score = scores_at_targets.sum() + transitions_to_stop.sum()

-        scores_upto_t = torch.zeros(batch_size, self.tagset_size, device=flair.device)
+        scores_upto_t = torch.zeros(batch_size, self.tagset_size, device=flair.device, dtype=features.dtype)

         for t in range(max(lengths)):
             batch_size_t = sum(
@@ -151,7 +151,7 @@ def decode(
         seq_len = features.size(1)

         # Create a tensor to hold accumulated sequence scores at each current tag
-        scores_upto_t = torch.zeros(batch_size, seq_len + 1, self.tagset_size).to(flair.device)
+        scores_upto_t = torch.zeros(batch_size, seq_len + 1, self.tagset_size, dtype=features.dtype).to(flair.device)
        # Create a tensor to hold back-pointers
        # i.e., indices of the previous_tag that corresponds to maximum accumulated score at current tag
        # Let pads be the <end> tag index, since that was the last tag in the decoded sequence
54 changes: 22 additions & 32 deletions flair/trainers/language_model_trainer.py
@@ -4,7 +4,7 @@
 import random
 import time
 from pathlib import Path
-from typing import Iterable, Optional, Type, Union
+from typing import Any, Dict, Iterable, Optional, Type, Union

 import torch
 from torch import cuda
@@ -13,16 +13,10 @@
 from torch.optim.sgd import SGD
 from torch.utils.data import DataLoader, Dataset

-from flair.optim import SGDW, ReduceLRWDOnPlateau
-
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
 import flair
 from flair.data import Dictionary
 from flair.models import LanguageModel
+from flair.optim import SGDW, ReduceLRWDOnPlateau
 from flair.training_utils import add_file_handler

 log = logging.getLogger("flair")
@@ -166,7 +160,8 @@ def __init__(
         epoch: int = 0,
         split: int = 0,
         loss: float = 10000,
-        optimizer_state: Optional[dict] = None,
+        optimizer_state: Optional[Dict[str, Any]] = None,
+        scaler_state: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.model: LanguageModel = model
         self.optimizer: Type[Optimizer] = optimizer
@@ -179,6 +174,7 @@ def __init__(
         self.split = split
         self.loss = loss
         self.optimizer_state = optimizer_state
+        self.scaler_state = scaler_state

     def train(
         self,
@@ -194,16 +190,8 @@ def train(
         grow_to_sequence_length: int = 0,
         num_workers: int = 2,
         use_amp: bool = False,
-        amp_opt_level: str = "O1",
         **kwargs,
     ):
-        if use_amp and amp is None:
-            raise RuntimeError(
-                "Failed to import apex. Please install apex from "
-                "https://www.github.com/nvidia/apex "
-                "to enable mixed-precision training."
-            )
-
         # cast string to Path
         base_path = Path(base_path)
@@ -229,6 +217,12 @@
         best_val_loss = self.loss
         kwargs["lr"] = learning_rate
         optimizer = self.optimizer(self.model.parameters(), **kwargs)
+
+        scaler = torch.cuda.amp.GradScaler(enabled=use_amp and flair.device.type != "cpu")
+
+        if self.scaler_state:
+            scaler.load_state_dict(self.scaler_state)
+
         if self.optimizer_state is not None:
             optimizer.load_state_dict(self.optimizer_state)
@@ -239,9 +233,6 @@
         else:
             scheduler = ReduceLROnPlateau(optimizer, verbose=True, factor=anneal_factor, patience=patience)

-        if use_amp:
-            self.model, optimizer = amp.initialize(self.model, optimizer, opt_level=amp_opt_level)
-
         training_generator = DataLoader(self.corpus.train, shuffle=False, num_workers=num_workers)

         for epoch in range(self.epoch, max_epochs):
@@ -296,24 +287,22 @@ def train(

             self.model.zero_grad()
             optimizer.zero_grad()
+            with torch.autocast(device_type=flair.device.type, enabled=use_amp):
+                # do the forward pass in the model
+                output, rnn_output, hidden = self.model.forward(data, hidden)

-            # do the forward pass in the model
-            output, rnn_output, hidden = self.model.forward(data, hidden)
-
-            # try to predict the targets
-            loss = self.loss_function(output.view(-1, ntokens), targets)
+                # try to predict the targets
+                loss = self.loss_function(output.view(-1, ntokens), targets)
             # Backward
-            if use_amp:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
-            else:
-                loss.backward()
+            scaler.scale(loss).backward()
+
+            scaler.unscale_(optimizer)

             # `clip_grad_norm` helps prevent the exploding gradient
             # problem in RNNs / LSTMs.
             torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)

-            optimizer.step()
+            scaler.step(optimizer)
+            scaler.update()

             total_loss += loss.data
@@ -467,4 +456,5 @@ def load_checkpoint(
             split=checkpoint["split"],
             loss=checkpoint["loss"],
             optimizer_state=checkpoint["optimizer_state_dict"],
+            scaler_state=checkpoint.get("scaler_state_dict"),
         )
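Since the GradScaler carries state of its own (the current loss scale and its growth tracker), the language model trainer now also restores it when resuming, via the new scaler_state argument and the scaler_state_dict checkpoint key. The saving side is not part of the excerpt shown here; the round trip looks roughly like the sketch below, where everything except the checkpoint key is a placeholder.

import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# Saving: persist the scaler next to model and optimizer state
# ("scaler_state_dict" mirrors the key read in load_checkpoint above).
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
    },
    "checkpoint.pt",  # placeholder path
)

# Resuming: restore the loss scale so training continues where it left off.
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
scaler_state = checkpoint.get("scaler_state_dict")
if scaler_state:
    scaler.load_state_dict(scaler_state)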
5 changes: 4 additions & 1 deletion flair/trainers/plugins/functional/linear_scheduler.py
@@ -61,12 +61,15 @@ def before_training_epoch(self, **kw):
         self.previous_learning_rate = self.current_learning_rate

     @TrainerPlugin.hook
-    def after_training_batch(self, **kw):
+    def after_training_batch(self, optimizer_was_run: bool, **kw):
         """Do the scheduler step if one-cycle or linear decay.
         :param kw:
         :return:
         """
+        # skip if no optimization has happened.
+        if not optimizer_was_run:
+            return
         self.scheduler.step()
         self.store_learning_rate()
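The new optimizer_was_run argument exists because GradScaler.step() silently skips the optimizer update when it finds inf/NaN gradients; stepping the LR scheduler anyway would let the schedule run ahead of the number of updates actually applied. The flag is produced in trainer.py (next file) by comparing the scale before and after scaler.update(), since the scale is only lowered when a step was skipped. A self-contained sketch of that detection logic, with an illustrative model and scheduler:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
scaler = torch.cuda.amp.GradScaler(enabled=device == "cuda")

x = torch.randn(2, 4, device=device)
with torch.autocast(device_type=device, enabled=device == "cuda"):
    loss = model(x).sum()

scaler.scale(loss).backward()

scale_before = scaler.get_scale()
scaler.step(optimizer)   # silently skipped if inf/NaN gradients were found
scaler.update()          # lowers the scale after a skipped step, otherwise keeps (or periodically raises) it
scale_after = scaler.get_scale()

optimizer_was_run = scale_before <= scale_after
if optimizer_was_run:
    scheduler.step()     # keep the LR schedule in sync with the updates actually applied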
94 changes: 54 additions & 40 deletions flair/trainers/trainer.py
@@ -230,6 +230,8 @@ def fine_tune(
         create_file_logs: bool = True,
         create_loss_file: bool = True,
         write_weights: bool = False,
+        # amp
+        use_amp: bool = False,
         # plugins
         plugins: Optional[List[TrainerPlugin]] = None,
         attach_default_scheduler: bool = True,
@@ -276,6 +278,8 @@ def fine_tune(
             create_file_logs=create_file_logs,
             create_loss_file=create_loss_file,
             write_weights=write_weights,
+            # amp
+            use_amp=use_amp,
             # plugins
             plugins=plugins,
             **kwargs,
@@ -316,49 +320,51 @@ def train_custom(
         create_file_logs: bool = True,
         create_loss_file: bool = True,
         write_weights: bool = False,
+        # amp
+        use_amp: bool = False,
         # plugins
         plugins: List[TrainerPlugin] = [],
         **kwargs,
     ) -> dict:
         """Trains any class that implements the flair.nn.Model interface.
-        Parameters
-        ----------
-        base_path: Main path to which all output during training is logged and models are saved
-        learning_rate (float): The learning rate of the optimizer
-        decoder_learning_rate (Optional[float]): Optional, if set, the decoder is trained with a separate learning rate
-        mini_batch_size (int): Size of mini-batches during training
-        eval_batch_size (int): Size of mini-batches during evaluation
-        mini_batch_chunk_size (int): If mini-batches are larger than this number, they get broken down into chunks of
-        this size for processing purposes
-        max_epochs (int): Maximum number of epochs to train. Terminates training if this number is surpassed.
-        optimizer: The optimizer to use (typically SGD or Adam)
-        train_with_dev (bool): If True, the data from dev split is added to the training data
-        train_with_test (bool): If True, the data from test split is added to the training data
-        main_evaluation_metric: The metric to optimize (often micro-average or macro-average F1-score, or accuracy)
-        monitor_test (bool): If True, test data is evaluated at end of each epoch
-        monitor_train_sample: Set this to evaluate on a sample of the train data at the end of each epoch.
-        If you set an int, it will sample this many sentences to evaluate on. If you set a float, it will sample
-        a percentage of data points from train.
-        use_final_model_for_eval (bool): If True, the final model is used for the final evaluation. If False, the
-        model from the best epoch as determined by main_evaluation_metric is used for the final evaluation.
-        gold_label_dictionary_for_eval: Set to force evaluation to use a particular label dictionary
-        exclude_labels: Optionally define a list of labels to exclude from the evaluation
-        sampler: You can pass a data sampler here for special sampling of data.
-        shuffle: If True, data is shuffled during training
-        shuffle_first_epoch: If True, data is shuffled during the first epoch of training
-        embeddings_storage_mode: One of 'none' (all embeddings are deleted and freshly recomputed),
-        'cpu' (embeddings stored on CPU) or 'gpu' (embeddings stored on GPU)
-        epoch: The starting epoch (normally 0 but could be higher if you continue training model)
-        save_final_model: If True, the final model is saved at the end of training.
-        save_optimizer_state (bool): If True, the optimizer state is saved alongside the model
-        save_model_each_k_epochs: Each k epochs, a model state will be written out. If set to '5', a model will
-        be saved each 5 epochs. Default is 0 which means no model saving.
-        create_file_logs (bool): If True, logging output is written to a file
-        create_loss_file (bool): If True, a loss file logging output is created
-        write_weights (bool): If True, write weights to weights.txt on each batch logging event.
-        plugins: Any additional plugins you want to pass to the trainer
-        **kwargs: Additional arguments, for instance for the optimizer
+        Args:
+            base_path: Main path to which all output during training is logged and models are saved
+            learning_rate (float): The learning rate of the optimizer
+            decoder_learning_rate (Optional[float]): Optional, if set, the decoder is trained with a separate learning rate
+            mini_batch_size (int): Size of mini-batches during training
+            eval_batch_size (int): Size of mini-batches during evaluation
+            mini_batch_chunk_size (int): If mini-batches are larger than this number, they get broken down into chunks of
+                this size for processing purposes
+            max_epochs (int): Maximum number of epochs to train. Terminates training if this number is surpassed.
+            optimizer: The optimizer to use (typically SGD or Adam)
+            train_with_dev (bool): If True, the data from dev split is added to the training data
+            train_with_test (bool): If True, the data from test split is added to the training data
+            main_evaluation_metric: The metric to optimize (often micro-average or macro-average F1-score, or accuracy)
+            monitor_test (bool): If True, test data is evaluated at end of each epoch
+            monitor_train_sample: Set this to evaluate on a sample of the train data at the end of each epoch.
+                If you set an int, it will sample this many sentences to evaluate on. If you set a float, it will sample
+                a percentage of data points from train.
+            use_final_model_for_eval (bool): If True, the final model is used for the final evaluation. If False, the
+                model from the best epoch as determined by main_evaluation_metric is used for the final evaluation.
+            gold_label_dictionary_for_eval: Set to force evaluation to use a particular label dictionary
+            exclude_labels: Optionally define a list of labels to exclude from the evaluation
+            sampler: You can pass a data sampler here for special sampling of data.
+            shuffle: If True, data is shuffled during training
+            shuffle_first_epoch: If True, data is shuffled during the first epoch of training
+            embeddings_storage_mode: One of 'none' (all embeddings are deleted and freshly recomputed),
+                'cpu' (embeddings stored on CPU) or 'gpu' (embeddings stored on GPU)
+            epoch: The starting epoch (normally 0 but could be higher if you continue training model)
+            save_final_model: If True, the final model is saved at the end of training.
+            save_optimizer_state (bool): If True, the optimizer state is saved alongside the model
+            save_model_each_k_epochs: Each k epochs, a model state will be written out. If set to '5', a model will
+                be saved each 5 epochs. Default is 0 which means no model saving.
+            create_file_logs (bool): If True, logging output is written to a file
+            create_loss_file (bool): If True, a loss file logging output is created
+            use_amp (bool): If True, uses the torch automatic mixed precision
+            write_weights (bool): If True, write weights to weights.txt on each batch logging event.
+            plugins: Any additional plugins you want to pass to the trainer
+            **kwargs: Additional arguments, for instance for the optimizer

         Returns:
         -------
@@ -471,6 +477,8 @@ def train_custom(
             # -- AnnealingPlugin -> initialize schedulers (requires instantiated optimizer)
             self.dispatch("after_setup", **parameters)

+            scaler = torch.cuda.amp.GradScaler(enabled=use_amp and flair.device.type != "cpu")
+
             final_eval_info = (
                 "model after last epoch (final-model.pt)"
                 if use_final_model_for_eval
@@ -567,12 +575,13 @@ def train_custom(
                     # forward and backward for batch
                     for batch_step in batch_steps:
                         # forward pass
-                        loss, datapoint_count = self.model.forward_loss(batch_step)
+                        with torch.autocast(device_type=flair.device.type, enabled=use_amp):
+                            loss, datapoint_count = self.model.forward_loss(batch_step)

                         batch_train_samples += datapoint_count
                         batch_train_loss += loss.item()

-                        self._backward(loss)
+                        self._backward(scaler.scale(loss))

                         # identify dynamic embeddings (always deleted) on first sentence
                         if dynamic_embeddings is None:
@@ -584,8 +593,13 @@ def train_custom(
                     self.dispatch("before_training_optimizer_step", **batch_kw)

                     # do the optimizer step
+                    scaler.unscale_(self.optimizer)
                     torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
-                    self.optimizer.step()
+                    scale_before = scaler.get_scale()
+                    scaler.step(self.optimizer)
+                    scaler.update()
+                    scale_after = scaler.get_scale()
+                    batch_kw["optimizer_was_run"] = scale_before <= scale_after

                     if batch_train_samples > 0:
                         train_loss = batch_train_loss / batch_train_samples
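With these changes in place, mixed precision in the main trainer is switched on through the new use_amp flag instead of apex's amp_opt_level. A hypothetical usage sketch follows; the corpus, embeddings and output path are placeholders for a typical Flair training setup, and only the use_amp flag itself is what this commit adds.

from flair.datasets import UD_ENGLISH
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

corpus = UD_ENGLISH().downsample(0.1)          # placeholder corpus
label_dict = corpus.make_label_dictionary(label_type="upos")

tagger = SequenceTagger(
    hidden_size=256,
    embeddings=TransformerWordEmbeddings("distilbert-base-uncased", fine_tune=True),
    tag_dictionary=label_dict,
    tag_type="upos",
)

trainer = ModelTrainer(tagger, corpus)

# fine_tune() and train_custom() both accept the new flag; forward passes then run
# under torch.autocast and updates go through the GradScaler shown in the diff above.
trainer.fine_tune(
    "resources/taggers/upos-amp-demo",         # placeholder output path
    max_epochs=1,
    use_amp=True,
)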