From 8ebd18ccf3c389f7ba42aa399b7d5df6977d3e96 Mon Sep 17 00:00:00 2001
From: Tristan Rice
Date: Thu, 21 Sep 2023 22:47:53 -0700
Subject: [PATCH] train: add fault tolerance + work to get rid of mmlab

---
 torchdrive/models/det.py      | 32 ++++++++++++++
 torchdrive/models/semantic.py |  2 +-
 train.py                      | 80 +++++++++++++++++++++--------------
 3 files changed, 81 insertions(+), 33 deletions(-)

diff --git a/torchdrive/models/det.py b/torchdrive/models/det.py
index ff18fc5..218249b 100644
--- a/torchdrive/models/det.py
+++ b/torchdrive/models/det.py
@@ -189,3 +189,35 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         bboxes = bboxes.float().sigmoid()  # normalized 0 to 1
 
         return classes, bboxes
+
+
+if __name__ == "__main__":
+    from functools import partial
+
+    m = BDD100KDet(device=torch.device("cpu"))
+    model = m.model.__self__
+    img, img_meta = m.transform(torch.rand(2, 3, 120, 240))
+    img_meta_list = [img_meta]
+    original_forward = model.forward
+
+    def forward(imgs):
+        return original_forward(
+            [imgs], img_metas=img_meta_list, return_loss=False, rescale=False
+        )
+
+    model.forward = forward
+    # model.forward = partial(
+    #     model.forward,
+    #     img_metas=img_meta_list,
+    #     return_loss=False,
+    #     rescale=False)
+    # model.forward.__globals__ = original_forward.__globals__
+    # model.forward.__code__ = original_forward.__code__
+    # model.forward.__closure__ = original_forward.__closure__
+
+    model(img)
+
+    # scripted = torch.jit.script(model, example_inputs=[[[img]]])
+    scripted = torch.jit.trace(model, img)
+
+    breakpoint()
diff --git a/torchdrive/models/semantic.py b/torchdrive/models/semantic.py
index e39a86e..b41605e 100644
--- a/torchdrive/models/semantic.py
+++ b/torchdrive/models/semantic.py
@@ -108,7 +108,7 @@ def __init__(
         device: torch.device,
         half: bool = True,
         config: str = "upernet_convnext-t_fp16_512x1024_80k_sem_seg_bdd100k.py",
-        mmlab: bool = False,
+        mmlab: bool = True,
         compile_fn: Callable[[nn.Module], nn.Module] = lambda m: m,
     ) -> None:
         if device == torch.device("cpu"):
diff --git a/train.py b/train.py
index 748b94e..4a4e2a6 100644
--- a/train.py
+++ b/train.py
@@ -388,25 +388,58 @@ def cam_encoder() -> RegNetEncoder:
 # scaler: amp.GradScaler = amp.GradScaler()
 scaler: Optional[amp.GradScaler] = None
 
-if args.load:
-    state_dict: Dict[str, torch.Tensor] = torch.load(
-        args.load, map_location=device, weights_only=True
+global_step = 0
+CHECKPOINT_PATH = os.path.join(args.output, "model.pt")
+GLOBAL_STEP_KEY = "global_step"
+MODEL_KEY = "model"
+OPTIM_KEY = "optim"
+
+
+def save(epoch: int) -> None:
+    if RANK != 0:
+        return
+    loss = epoch_loss / batch_idx if batch_idx else 0
+    torch.save(
+        {
+            MODEL_KEY: model.state_dict(),
+            OPTIM_KEY: optimizer.state_dict(),
+            "epoch": epoch,
+            GLOBAL_STEP_KEY: global_step,
+            "loss": loss,
+        },
+        CHECKPOINT_PATH,
     )
+    print(f"saved to {CHECKPOINT_PATH}, loss = {loss}")
+
+
+load_path = args.load
 
-    # new save format
-    if "optim" in state_dict:
-        if not args.skip_load_optim:
-            print("loading optim state_dict")
-            optim_dict: Dict[str, object] = state_dict["optim"]  # pyre-fixme
-            optim_dict = transfer("optim_dict", optim_dict, device=torch.device("cpu"))
-            optimizer.load_state_dict(optim_dict)
+LOAD_FAULT_TOLERANCE = os.path.exists(CHECKPOINT_PATH)
+
+if LOAD_FAULT_TOLERANCE:
+    print(f"loading from fault tolerance checkpoint {CHECKPOINT_PATH}")
+    load_path = CHECKPOINT_PATH
+
+if load_path:
+    ckpt: Dict[str, torch.Tensor] = torch.load(
+        load_path, map_location=device, weights_only=True
+    )
 
-            # NOTE: this overrides any LR set by schedulers
-            assert len(lr_groups) == len(optimizer.param_groups)
-            for lr, og in zip(lr_groups, optimizer.param_groups):
-                og["lr"] = lr
+    if not args.skip_load_optim or LOAD_FAULT_TOLERANCE:
+        print("loading optim state_dict")
+        optim_dict: Dict[str, object] = ckpt[OPTIM_KEY]  # pyre-fixme
+        optim_dict = transfer("optim_dict", optim_dict, device=torch.device("cpu"))
+        optimizer.load_state_dict(optim_dict)
 
-        state_dict = state_dict["model"]  # pyre-fixme
+    # NOTE: this overrides any LR set by schedulers
+    assert len(lr_groups) == len(optimizer.param_groups)
+    for lr, og in zip(lr_groups, optimizer.param_groups):
+        og["lr"] = lr
+
+    if GLOBAL_STEP_KEY in ckpt:
+        global_step = ckpt[GLOBAL_STEP_KEY]
+
+    state_dict = ckpt[MODEL_KEY]  # pyre-fixme
 
     # remap state_dict
     state_dict = remap_state_dict(state_dict, model)
@@ -418,8 +451,6 @@ def cam_encoder() -> RegNetEncoder:
         print(f"failed to load state_dict, err: {e}")
 
 
-global_step = 0
-
 meaned_losses: Dict[str, Union[float, torch.Tensor]] = {}
 
 
@@ -429,21 +460,6 @@ def reset_metrics() -> None:
     loss_count = 0
 
 
-def save(epoch: int) -> None:
-    if RANK != 0:
-        return
-    path = os.path.join(args.output, f"model_{epoch}.pt")
-    torch.save(
-        {
-            "model": model.state_dict(),
-            "optim": optimizer.state_dict(),
-        },
-        path,
-    )
-    l = epoch_loss / batch_idx if batch_idx else 0
-    print(f"saved to {path}, loss = {l}")
-
-
 if args.profile:  # and rank == 0:
     prof: Optional[torch.profiler.profile] = torch.profiler.profile(
         schedule=torch.profiler.schedule(wait=10, warmup=1, active=1, repeat=1),