From 882d01eda689cb0c43bd70c858c35a3f97103806 Mon Sep 17 00:00:00 2001
From: Alexandre Marques
Date: Mon, 30 Oct 2023 17:24:40 -0400
Subject: [PATCH] Only apply DDP after manager is applied and weights are
 loaded (#1796)

---
 src/sparseml/pytorch/torchvision/train.py | 34 ++++++++++-------------
 1 file changed, 14 insertions(+), 20 deletions(-)

diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py
index d3620d503eb..d3b3a7a0e22 100644
--- a/src/sparseml/pytorch/torchvision/train.py
+++ b/src/sparseml/pytorch/torchvision/train.py
@@ -401,30 +401,27 @@ def collate_fn(batch):
 
     _LOGGER.info("Creating model")
     local_rank = int(os.environ["LOCAL_RANK"]) if args.distributed else None
-    model, arch_key, maybe_dp_device = _create_model(
+    model, arch_key = _create_model(
         arch_key=args.arch_key,
         local_rank=local_rank,
         pretrained=args.pretrained,
         checkpoint_path=args.checkpoint_path,
         pretrained_dataset=args.pretrained_dataset,
-        device=device,
         num_classes=num_classes,
     )
 
     if args.distill_teacher not in ["self", "disable", None]:
         _LOGGER.info("Instantiating teacher")
-        distill_teacher, _, _ = _create_model(
+        distill_teacher, _ = _create_model(
             arch_key=args.teacher_arch_key,
             local_rank=local_rank,
             pretrained=True,  # teacher is always pretrained
             pretrained_dataset=args.pretrained_teacher_dataset,
             checkpoint_path=args.distill_teacher,
-            device=device,
             num_classes=num_classes,
         )
     else:
         distill_teacher = args.distill_teacher
-    device = maybe_dp_device
 
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -507,7 +504,7 @@ def collate_fn(batch):
         alpha = 1.0 - args.model_ema_decay
         alpha = min(1.0, alpha * adjust)
         model_ema = utils.ExponentialMovingAverage(
-            model, device=device, decay=1.0 - alpha
+            model, device=model.device, decay=1.0 - alpha
         )
 
     manager = checkpoint_manager = None
@@ -651,9 +648,17 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
         args, optimizer, checkpoint=checkpoint, manager=manager
     )
 
-    model_without_ddp = model
     if args.distributed:
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        ddp = True
+        device = local_rank
+    else:
+        ddp = False
+
+    model, device, _ = model_to_device(model, device, ddp)
+    if distill_teacher is not None:
+        distill_teacher, _, _ = model_to_device(distill_teacher, device, ddp)
+
+    if args.distributed:
         model_without_ddp = model.module
 
     best_top1_acc = -math.inf
@@ -760,7 +765,6 @@ def _create_model(
     pretrained: Optional[bool] = False,
     checkpoint_path: Optional[str] = None,
     pretrained_dataset: Optional[str] = None,
-    device=None,
     num_classes=None,
 ):
     if not arch_key or arch_key in ModelRegistry.available_keys():
@@ -811,17 +815,7 @@ def _create_model(
         raise ValueError(
             f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"
         )
-    ddp = False
-    if local_rank is not None:
-        torch.cuda.set_device(local_rank)
-        device = local_rank
-        ddp = True
-    model, device, _ = model_to_device(
-        model=model,
-        device=device,
-        ddp=ddp,
-    )
-    return model, arch_key, device
+    return model, arch_key
 
 
 def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None):
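
The ordering the patch enforces is: build the model and restore its weights (and apply the sparsification manager) while it is still a plain nn.Module, and only then move it to its device and wrap it in DistributedDataParallel, so the wrapper registers the final parameter set and broadcasts the loaded weights. The sketch below is a minimal illustration of that ordering in plain PyTorch under torchrun, not the sparseml code itself: the patch routes the device/DDP step through sparseml's model_to_device helper, while CHECKPOINT_PATH and the recipe step here are hypothetical placeholders.

    # Minimal sketch of the patch's ordering, assuming launch via
    # `torchrun --nproc_per_node=N sketch.py`; CHECKPOINT_PATH is a
    # hypothetical placeholder, not part of the sparseml CLI.
    import os

    import torch
    import torch.distributed as dist
    import torchvision

    CHECKPOINT_PATH = "checkpoint.pth"  # hypothetical checkpoint location


    def main():
        local_rank = int(os.environ["LOCAL_RANK"])
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)

        # 1) Build the model and restore weights while it is still an
        #    unwrapped module on CPU.
        model = torchvision.models.resnet50(num_classes=1000)
        if os.path.exists(CHECKPOINT_PATH):
            state = torch.load(CHECKPOINT_PATH, map_location="cpu")
            model.load_state_dict(state)

        # 2) Apply any structure-changing recipe here (e.g. sparseml's
        #    manager) before the model is wrapped.

        # 3) Only now move to the local GPU and wrap in DDP, so DDP sees the
        #    final parameters and broadcasts the loaded weights to all ranks.
        model = model.to(local_rank)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank]
        )
        model_without_ddp = model.module  # unwrapped handle, as in train.py

        # ... training loop would go here ...

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()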