
Commit

Only apply DDP after manager is applied and weights are loaded (#1796)
anmarques authored Oct 30, 2023
1 parent d8e9a45 commit 882d01e
Showing 1 changed file with 14 additions and 20 deletions.
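
The ordering this commit enforces — restore the checkpoint and apply the SparseML manager to the bare module first, and only then move the model to its device and wrap it in DistributedDataParallel — follows the usual PyTorch pattern sketched below. This is an illustrative outline only, not code from the repository; it assumes the distributed process group has already been initialized and that LOCAL_RANK is set.

import os
import torch
import torchvision

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# 1) Build the bare module and restore weights into it first.
model = torchvision.models.resnet50(num_classes=1000)
# state_dict = torch.load("checkpoint.pth", map_location="cpu")  # hypothetical path
# model.load_state_dict(state_dict)

# 2) Apply any structural modifications (e.g. a SparseML recipe/manager) here,
#    while the module is still unwrapped.

# 3) Only then move the module to its device and wrap it with DDP.
model = model.to(local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])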
34 changes: 14 additions & 20 deletions src/sparseml/pytorch/torchvision/train.py
@@ -401,30 +401,27 @@ def collate_fn(batch):
 
     _LOGGER.info("Creating model")
     local_rank = int(os.environ["LOCAL_RANK"]) if args.distributed else None
-    model, arch_key, maybe_dp_device = _create_model(
+    model, arch_key = _create_model(
         arch_key=args.arch_key,
         local_rank=local_rank,
         pretrained=args.pretrained,
         checkpoint_path=args.checkpoint_path,
         pretrained_dataset=args.pretrained_dataset,
-        device=device,
         num_classes=num_classes,
     )
 
     if args.distill_teacher not in ["self", "disable", None]:
         _LOGGER.info("Instantiating teacher")
-        distill_teacher, _, _ = _create_model(
+        distill_teacher, _ = _create_model(
             arch_key=args.teacher_arch_key,
             local_rank=local_rank,
             pretrained=True,  # teacher is always pretrained
             pretrained_dataset=args.pretrained_teacher_dataset,
             checkpoint_path=args.distill_teacher,
-            device=device,
             num_classes=num_classes,
         )
     else:
         distill_teacher = args.distill_teacher
-    device = maybe_dp_device
 
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -507,7 +504,7 @@ def collate_fn(batch):
         alpha = 1.0 - args.model_ema_decay
         alpha = min(1.0, alpha * adjust)
         model_ema = utils.ExponentialMovingAverage(
-            model, device=device, decay=1.0 - alpha
+            model, device=model.device, decay=1.0 - alpha
         )
 
     manager = checkpoint_manager = None
@@ -651,9 +648,17 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
         args, optimizer, checkpoint=checkpoint, manager=manager
     )
 
-    model_without_ddp = model
     if args.distributed:
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        ddp = True
+        device = local_rank
+    else:
+        ddp = False
+
+    model, device, _ = model_to_device(model, device, ddp)
+    if distill_teacher is not None:
+        distill_teacher, _, _ = model_to_device(distill_teacher, device, ddp)
+
+    if args.distributed:
         model_without_ddp = model.module
 
     best_top1_acc = -math.inf
@@ -760,7 +765,6 @@ def _create_model(
     pretrained: Optional[bool] = False,
     checkpoint_path: Optional[str] = None,
     pretrained_dataset: Optional[str] = None,
-    device=None,
     num_classes=None,
 ):
     if not arch_key or arch_key in ModelRegistry.available_keys():
@@ -811,17 +815,7 @@
         raise ValueError(
             f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"
         )
-    ddp = False
-    if local_rank is not None:
-        torch.cuda.set_device(local_rank)
-        device = local_rank
-        ddp = True
-    model, device, _ = model_to_device(
-        model=model,
-        device=device,
-        ddp=ddp,
-    )
-    return model, arch_key, device
+    return model, arch_key
 
 
 def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None):
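The model_to_device helper now called from the training loop comes from the sparseml PyTorch utilities and is not shown in this diff. Based on the call sites above and on the code removed from _create_model, it can be approximated by the sketch below — an assumption for illustration, not the library's actual implementation.

import torch


def model_to_device(model, device, ddp=False):
    """Rough stand-in for sparseml's model_to_device (assumed behavior).

    When ddp is True, `device` is taken to be the local rank; the third
    return value is ignored at the call sites shown in this commit.
    """
    if ddp:
        torch.cuda.set_device(device)
        model = model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[device]
        )
    else:
        model = model.to(device)
    return model, device, None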
