-
Notifications
You must be signed in to change notification settings - Fork 4
/
util.py
61 lines (46 loc) · 2.5 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import shutil
import logging
from typing import Type, List
from argparse import Namespace
from cosface_loss import MarginCosineProduct
def move_to_device(optimizer: Type[torch.optim.Optimizer], device: str):
    """Move every tensor stored in the optimizer's state to `device`.

    Needed after loading an optimizer state_dict so that momentum buffers
    and similar per-parameter tensors live on the same device as the model.
    """
    for param_state in optimizer.state.values():
        tensor_keys = [key for key, value in param_state.items() if torch.is_tensor(value)]
        for key in tensor_keys:
            param_state[key] = param_state[key].to(device)
def save_checkpoint(state: dict, is_best: bool, output_folder: str,
                    ckpt_filename: str = "last_checkpoint.pth"):
    """Serialize the full training state to `output_folder`.

    Always writes `ckpt_filename`; when `is_best` is True, additionally
    snapshots only the model weights to `best_model.pth` so the best model
    can be restored without the optimizer/classifier state.
    """
    torch.save(state, f"{output_folder}/{ckpt_filename}")
    if is_best:
        best_model_path = f"{output_folder}/best_model.pth"
        torch.save(state["model_state_dict"], best_model_path)
def resume_train(args: Namespace, output_folder: str, model: torch.nn.Module,
                 model_optimizer: Type[torch.optim.Optimizer],
                 classifiers: List[MarginCosineProduct], classifiers_optimizers: List[Type[torch.optim.Optimizer]]):
    """Resume training from the checkpoint at args.resume_train.

    Loads the model, its optimizer, every per-group classifier and classifier
    optimizer, the best validation recall@1 seen so far, and the epoch to
    resume from. Also copies the accompanying best_model.pth into
    output_folder so the new run keeps the previous best snapshot.

    Returns:
        (model, model_optimizer, classifiers, classifiers_optimizers,
         best_val_recall1, start_epoch_num)
    """
    logging.info(f"Loading checkpoint: {args.resume_train}")
    checkpoint = torch.load(args.resume_train)
    start_epoch_num = checkpoint["epoch_num"]
    model_state_dict = checkpoint["model_state_dict"]
    model.load_state_dict(model_state_dict)
    model = model.to(args.device)
    model_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # load classifiers and optimizers
    # Sanity check: the checkpoint must contain one classifier (and one
    # optimizer) state per group. NOTE(review): the expected count is
    # groups_num*2 — presumably two classifiers per group by construction
    # elsewhere; confirm against the training script.
    assert args.groups_num*2 == len(classifiers) == len(classifiers_optimizers) == \
        len(checkpoint["classifiers_state_dict"]) == len(checkpoint["optimizers_state_dict"]), \
        (f"{args.groups_num}, {len(classifiers)}, {len(classifiers_optimizers)}, "
         f"{len(checkpoint['classifiers_state_dict'])}, {len(checkpoint['optimizers_state_dict'])}")
    for c, sd in zip(classifiers, checkpoint["classifiers_state_dict"]):
        # Move classifiers to GPU before loading their optimizers
        # NOTE: nn.Module.to()/.cpu() move parameters in place, so rebinding
        # the loop variable `c` still affects the module in `classifiers`.
        c = c.to(args.device)
        c.load_state_dict(sd)
    for c, sd in zip(classifiers_optimizers, checkpoint["optimizers_state_dict"]):
        c.load_state_dict(sd)
    for c in classifiers:
        # Move classifiers back to CPU to save some GPU memory
        c = c.cpu()
    best_val_recall1 = checkpoint["best_val_recall1"]
    # Copy best model_ to current output_folder
    # Assumes args.resume_train points at a "last_checkpoint.pth" file with a
    # sibling "best_model.pth" — TODO confirm this naming is always used.
    shutil.copy(args.resume_train.replace("last_checkpoint.pth", "best_model.pth"), output_folder)
    return model, model_optimizer, classifiers, classifiers_optimizers, best_val_recall1, start_epoch_num