Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EarlyStopping feature #4576

Merged
merged 19 commits into from
Aug 28, 2021
19 changes: 18 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
from utils.downloads import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_labels, plot_evolve
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device, \
torch_distributed_zero_first
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.metrics import fitness
from utils.loggers import Loggers
Expand Down Expand Up @@ -255,6 +256,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
results = (0, 0, 0, 0, 0, 0, 0) # P, R, [email protected], [email protected], val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1 # do not move
scaler = amp.GradScaler(enabled=cuda)
stopper = EarlyStopping(patience=opt.patience)
compute_loss = ComputeLoss(model) # init loss class
LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
f'Using {train_loader.num_workers} dataloader workers\n'
Expand Down Expand Up @@ -389,6 +391,20 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
del ckpt
callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)

# Stop Single-GPU
if stopper(epoch=epoch, fitness=fi):
break

# Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576
# stop = stopper(epoch=epoch, fitness=fi)
# if RANK == 0:
# dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks

# Stop DPP
# with torch_distributed_zero_first(RANK):
# if stop:
# break # must break all DDP ranks

# end epoch ----------------------------------------------------------------------------------------------------
# end training -----------------------------------------------------------------------------------------------------
if RANK in [-1, 0]:
Expand Down Expand Up @@ -454,6 +470,7 @@ def parse_opt(known=False):
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
parser.add_argument('--patience', type=int, default=30, help='EarlyStopping patience (epochs)')
opt = parser.parse_known_args()[0] if known else parser.parse_args()
return opt

Expand Down
17 changes: 17 additions & 0 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,23 @@ def copy_attr(a, b, include=(), exclude=()):
setattr(a, k, v)


class EarlyStopping:
# YOLOv5 simple early stopper
def __init__(self, patience=30):
self.best_fitness = 0.0 # i.e. mAP
self.best_epoch = 0
self.patience = patience # epochs to wait after fitness stops improving to stop

def __call__(self, epoch, fitness):
if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
self.best_epoch = epoch
self.best_fitness = fitness
stop = (epoch - self.best_epoch) >= self.patience # stop training if patience exceeded
if stop:
LOGGER.info(f'EarlyStopping patience {self.patience} exceeded, stopping training.')
return stop


class ModelEMA:
""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
Keep a moving average of everything in the model state_dict (parameters and buffers).
Expand Down