-
Notifications
You must be signed in to change notification settings - Fork 45
/
train.py
122 lines (107 loc) · 5.32 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import logging
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from importlib import import_module
from sam_lora_image_encoder import LoRA_Sam
from segment_anything import sam_model_registry
from trainer import trainer_synapse
from icecream import ic
# Command-line interface for LoRA fine-tuning of SAM on the Synapse dataset.
# Parsed once at import time; the resulting module-level `args` is consumed
# by the training driver below.
parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str,
                    default='/data/LarryXu/Synapse/preprocessed_data/train_npz', help='root dir for data')
parser.add_argument('--output', type=str, default='/output/sam/results')
parser.add_argument('--dataset', type=str,
                    default='Synapse', help='experiment_name')
parser.add_argument('--list_dir', type=str,
                    default='./lists/lists_Synapse', help='list dir')
parser.add_argument('--num_classes', type=int,
                    default=8, help='output channel of network')
parser.add_argument('--max_iterations', type=int,
                    default=30000, help='maximum iteration number to train')
parser.add_argument('--max_epochs', type=int,
                    default=200, help='maximum epoch number to train')
parser.add_argument('--stop_epoch', type=int,
                    default=160, help='epoch at which to stop training early')
parser.add_argument('--batch_size', type=int,
                    default=12, help='batch_size per gpu')
parser.add_argument('--n_gpu', type=int, default=2, help='total gpu')
parser.add_argument('--deterministic', type=int, default=1,
                    help='whether use deterministic training')
parser.add_argument('--base_lr', type=float, default=0.005,
                    help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
                    default=512, help='input patch size of network input')
parser.add_argument('--seed', type=int,
                    default=1234, help='random seed')
parser.add_argument('--vit_name', type=str,
                    default='vit_b', help='select one vit model')
parser.add_argument('--ckpt', type=str, default='checkpoints/sam_vit_b_01ec64.pth',
                    help='Pretrained checkpoint')
parser.add_argument('--lora_ckpt', type=str, default=None, help='Finetuned lora checkpoint')
parser.add_argument('--rank', type=int, default=4, help='Rank for LoRA adaptation')
parser.add_argument('--warmup', action='store_true',
                    help='If activated, warm up the learning rate from a lower lr to the base_lr')
parser.add_argument('--warmup_period', type=int, default=250,
                    help='Warm-up iterations, only valid when warmup is activated')
parser.add_argument('--AdamW', action='store_true', help='If activated, use AdamW to finetune SAM model')
parser.add_argument('--module', type=str, default='sam_lora_image_encoder')
parser.add_argument('--dice_param', type=float, default=0.8)
args = parser.parse_args()
if __name__ == "__main__":
    # --- Reproducibility: cuDNN mode + seed every RNG source we use. ---
    if not args.deterministic:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        cudnn.benchmark = False
        cudnn.deterministic = True
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    dataset_name = args.dataset
    # Per-dataset settings; only 'Synapse' is wired up in this script.
    dataset_config = {
        'Synapse': {
            'root_path': args.root_path,
            'list_dir': args.list_dir,
            'num_classes': args.num_classes,
        }
    }
    args.is_pretrain = True
    args.exp = dataset_name + '_' + str(args.img_size)

    # --- Build the snapshot directory name, appending only non-default
    # hyperparameters so the folder name encodes the run configuration. ---
    snapshot_path = os.path.join(args.output, "{}".format(args.exp))
    snapshot_path = snapshot_path + '_pretrain' if args.is_pretrain else snapshot_path
    snapshot_path += '_' + args.vit_name
    snapshot_path = snapshot_path + '_' + str(args.max_iterations)[
        0:2] + 'k' if args.max_iterations != 30000 else snapshot_path
    snapshot_path = snapshot_path + '_epo' + str(args.max_epochs) if args.max_epochs != 30 else snapshot_path
    snapshot_path = snapshot_path + '_bs' + str(args.batch_size)
    snapshot_path = snapshot_path + '_lr' + str(args.base_lr) if args.base_lr != 0.01 else snapshot_path
    snapshot_path = snapshot_path + '_s' + str(args.seed) if args.seed != 1234 else snapshot_path
    # exist_ok avoids the exists()/makedirs() TOCTOU race of the original.
    os.makedirs(snapshot_path, exist_ok=True)

    # --- Register the SAM backbone and wrap it with LoRA adapters. ---
    # pixel_mean/std of 0/1 disable SAM's built-in normalization (data is
    # presumably normalized upstream — TODO confirm against the dataset code).
    sam, img_embedding_size = sam_model_registry[args.vit_name](image_size=args.img_size,
                                                                num_classes=args.num_classes,
                                                                checkpoint=args.ckpt, pixel_mean=[0, 0, 0],
                                                                pixel_std=[1, 1, 1])
    # The LoRA wrapper class is loaded dynamically so --module can swap variants.
    pkg = import_module(args.module)
    net = pkg.LoRA_Sam(sam, args.rank).cuda()
    if args.lora_ckpt is not None:
        net.load_lora_parameters(args.lora_ckpt)

    # Multi-class segmentation needs SAM's multi-mask head.
    multimask_output = args.num_classes > 1

    # Low-resolution mask size fed to the trainer (4x the image embedding).
    low_res = img_embedding_size * 4

    # --- Persist the full run configuration for later inspection. ---
    config_file = os.path.join(snapshot_path, 'config.txt')
    with open(config_file, 'w') as f:
        f.writelines(f'{key}: {value}\n' for key, value in vars(args).items())

    trainer = {'Synapse': trainer_synapse}
    trainer[dataset_name](args, net, snapshot_path, multimask_output, low_res)