train: add fault tolerance + work to get rid of mmlab
d4l3k committed Sep 22, 2023
1 parent 29cad22 commit 8ebd18c
Showing 3 changed files with 81 additions and 33 deletions.
32 changes: 32 additions & 0 deletions torchdrive/models/det.py
@@ -189,3 +189,35 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
bboxes = bboxes.float().sigmoid() # normalized 0 to 1

return classes, bboxes


if __name__ == "__main__":
from functools import partial

m = BDD100KDet(device=torch.device("cpu"))
model = m.model.__self__
img, img_meta = m.transform(torch.rand(2, 3, 120, 240))
img_meta_list = [img_meta]
original_forward = model.forward

def forward(imgs):
return original_forward(
[imgs], img_metas=img_meta_list, return_loss=False, rescale=False
)

model.forward = forward
# model.forward = partial(
# model.forward,
# img_metas=img_meta_list,
# return_loss=False,
# rescale=False)
# model.forward.__globals__ = original_forward.__globals__
# model.forward.__code__ = original_forward.__code__
# model.forward.__closure__ = original_forward.__closure__

model(img)

# scripted = torch.jit.script(model, example_inputs=[[[img]]])
scripted = torch.jit.trace(model, img)

breakpoint()
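
The block above makes the mmdet-style detector traceable by swapping model.forward for a single-argument closure that pins img_metas, return_loss, and rescale, then handing the module to torch.jit.trace. Below is a minimal, self-contained sketch of the same idea; it uses a hypothetical ToyDetector stand-in (not the real BDD100K model) and wraps it in a small nn.Module instead of monkey-patching forward, so the traced graph only ever sees a tensor input.

import torch
import torch.nn as nn


class ToyDetector(nn.Module):
    # Hypothetical stand-in for a detector whose forward wants extra metadata kwargs.
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=1)

    def forward(self, imgs, img_metas=None, rescale=False):
        return self.conv(imgs)


class TraceWrapper(nn.Module):
    # Pins the non-tensor kwargs so torch.jit.trace sees a single tensor argument.
    def __init__(self, det: nn.Module) -> None:
        super().__init__()
        self.det = det

    def forward(self, imgs):
        return self.det(imgs, img_metas=None, rescale=False)


if __name__ == "__main__":
    img = torch.rand(2, 3, 120, 240)
    traced = torch.jit.trace(TraceWrapper(ToyDetector()).eval(), img)
    print(traced(img).shape)  # torch.Size([2, 4, 120, 240])

The wrapper-module route avoids mutating the detector instance; the closure route in the diff keeps the original module object, which matters if other code still holds a reference to it.
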
2 changes: 1 addition & 1 deletion torchdrive/models/semantic.py
@@ -108,7 +108,7 @@ def __init__(
device: torch.device,
half: bool = True,
config: str = "upernet_convnext-t_fp16_512x1024_80k_sem_seg_bdd100k.py",
mmlab: bool = False,
mmlab: bool = True,
compile_fn: Callable[[nn.Module], nn.Module] = lambda m: m,
) -> None:
if device == torch.device("cpu"):
80 changes: 48 additions & 32 deletions train.py
@@ -388,25 +388,58 @@ def cam_encoder() -> RegNetEncoder:
# scaler: amp.GradScaler = amp.GradScaler()
scaler: Optional[amp.GradScaler] = None

if args.load:
state_dict: Dict[str, torch.Tensor] = torch.load(
args.load, map_location=device, weights_only=True
global_step = 0
CHECKPOINT_PATH = os.path.join(args.output, "model.pt")
GLOBAL_STEP_KEY = "global_step"
MODEL_KEY = "model"
OPTIM_KEY = "optim"


def save(epoch: int) -> None:
if RANK != 0:
return
path = CHECKPOINT_PATH  # assumption: the fault-tolerance save targets the fixed checkpoint path
loss = epoch_loss / batch_idx if batch_idx else 0
torch.save(
{
MODEL_KEY: model.state_dict(),
OPTIM_KEY: optimizer.state_dict(),
"epoch": epoch,
GLOBAL_STEP_KEY: global_step,
"loss": loss,
},
path,
)
print(f"saved to {path}, loss = {loss}")


load_path = args.load

# new save format
if "optim" in state_dict:
if not args.skip_load_optim:
print("loading optim state_dict")
optim_dict: Dict[str, object] = state_dict["optim"] # pyre-fixme
optim_dict = transfer("optim_dict", optim_dict, device=torch.device("cpu"))
optimizer.load_state_dict(optim_dict)
LOAD_FAULT_TOLERANCE = os.path.exists(CHECKPOINT_PATH)

if LOAD_FAULT_TOLERANCE:
print(f"loading from fault tolerance checkpoint {CHECKPOINT_PATH}")
load_path = CHECKPOINT_PATH

if load_path:
ckpt: Dict[str, torch.Tensor] = torch.load(
load_path, map_location=device, weights_only=True
)

# NOTE: this overrides any LR set by schedulers
assert len(lr_groups) == len(optimizer.param_groups)
for lr, og in zip(lr_groups, optimizer.param_groups):
og["lr"] = lr
if not args.skip_load_optim or LOAD_FAULT_TOLERANCE:
print("loading optim state_dict")
optim_dict: Dict[str, object] = ckpt[OPTIM_KEY] # pyre-fixme
optim_dict = transfer("optim_dict", optim_dict, device=torch.device("cpu"))
optimizer.load_state_dict(optim_dict)

state_dict = state_dict["model"] # pyre-fixme
# NOTE: this overrides any LR set by schedulers
assert len(lr_groups) == len(optimizer.param_groups)
for lr, og in zip(lr_groups, optimizer.param_groups):
og["lr"] = lr

if GLOBAL_STEP_KEY in ckpt:
global_step = ckpt[GLOBAL_STEP_KEY]

state_dict = ckpt[MODEL_KEY] # pyre-fixme

# remap state_dict
state_dict = remap_state_dict(state_dict, model)
@@ -418,8 +451,6 @@ def cam_encoder() -> RegNetEncoder:
print(f"failed to load state_dict, err: {e}")


global_step = 0

meaned_losses: Dict[str, Union[float, torch.Tensor]] = {}


@@ -429,21 +460,6 @@ def reset_metrics() -> None:
loss_count = 0


def save(epoch: int) -> None:
if RANK != 0:
return
path = os.path.join(args.output, f"model_{epoch}.pt")
torch.save(
{
"model": model.state_dict(),
"optim": optimizer.state_dict(),
},
path,
)
l = epoch_loss / batch_idx if batch_idx else 0
print(f"saved to {path}, loss = {l}")


if args.profile: # and rank == 0:
prof: Optional[torch.profiler.profile] = torch.profiler.profile(
schedule=torch.profiler.schedule(wait=10, warmup=1, active=1, repeat=1),
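
Taken together, the train.py changes implement a resume-if-present checkpoint: model, optimizer, and global step are written to one fixed file under args.output, and on startup that file takes precedence over any args.load path. A minimal sketch of the pattern, assuming a toy model, optimizer, and an out/ directory in place of the script's real arguments:

import os

import torch
import torch.nn as nn

OUTPUT_DIR = "out"  # assumption: stands in for args.output
CHECKPOINT_PATH = os.path.join(OUTPUT_DIR, "model.pt")

model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
global_step = 0


def save_checkpoint(step: int) -> None:
    # Always overwrite the same file so a restarted job knows where to look.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    torch.save(
        {
            "model": model.state_dict(),
            "optim": optimizer.state_dict(),
            "global_step": step,
        },
        CHECKPOINT_PATH,
    )


# On startup: resume automatically if a fault-tolerance checkpoint exists.
if os.path.exists(CHECKPOINT_PATH):
    ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optim"])
    global_step = ckpt["global_step"]

for _ in range(3):  # stand-in training loop
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    global_step += 1
    save_checkpoint(global_step)

Keeping periodic epoch-numbered checkpoints alongside the fixed one (as the removed model_{epoch}.pt save did) is still useful for picking a best model after the fact; the fixed path only needs to hold the most recent state.
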
