-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
62 lines (50 loc) · 1.46 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import scipy.io as scio
import numpy as np
import logging
def generate_masks(mask_path, device="cuda"):
    """Load the masks stored in ``<mask_path>/mask.mat`` as float tensors.

    The ``.mat`` file must contain a variable named ``mask``. Its last axis
    is moved to the front via ``np.transpose(mask, [2, 0, 1])``, so the
    stored array has to be 3-D.

    Args:
        mask_path: Directory that contains ``mask.mat``.
        device: Device the returned tensors are moved to. Defaults to
            ``"cuda"``, which matches the previously hard-coded ``.cuda()``
            calls; pass ``"cpu"`` to use the function without a GPU.

    Returns:
        A ``(mask, mask_s)`` tuple of ``torch.float32`` tensors on
        ``device``. ``mask`` is the transposed mask stack; ``mask_s`` is its
        sum over the leading axis with every zero entry replaced by one.
    """
    mask = scio.loadmat(mask_path + '/mask.mat')['mask']
    mask = np.transpose(mask, [2, 0, 1])

    mask_s = np.sum(mask, axis=0)
    # Zero sums are replaced by one (presumably so callers can divide by
    # mask_s safely -- confirm against the call sites).
    mask_s[mask_s == 0] = 1
    # NOTE(review): cast kept for backward compatibility. uint8 wraps for
    # sums above 255 and truncates fractional sums (a value in (0, 1)
    # becomes 0 again) -- fine for binary masks, verify for anything else.
    mask_s = mask_s.astype(np.uint8)

    mask = torch.from_numpy(mask).float().to(device)
    mask_s = torch.from_numpy(mask_s).float().to(device)
    return mask, mask_s
def time2file_name(time):
    """Turn a timestamp string into an underscore-separated file-name stem.

    The characters at positions 0-3, 5-6, 8-9, 11-12, 14-15 and 17-18 are
    extracted and joined with ``'_'``; anything from position 19 onwards is
    dropped. For an input laid out like ``"YYYY-MM-DD HH:MM:SS[.ffffff]"``
    this yields ``"YYYY_MM_DD_HH_MM_SS"``.

    Args:
        time: Timestamp text to convert.

    Returns:
        The six extracted fields joined by underscores.
    """
    # (start, stop) bounds of year, month, day, hour, minute, second.
    field_bounds = ((0, 4), (5, 7), (8, 10), (11, 13), (14, 16), (17, 19))
    return '_'.join(time[start:stop] for start, stop in field_bounds)
def A(x, Phi):
    """Multiply ``x`` by ``Phi`` element-wise and sum the product over dim 1.

    Args:
        x: Tensor to be weighted.
        Phi: Tensor of weights, multiplied with ``x`` using ``*``.

    Returns:
        ``x * Phi`` reduced by summation along dimension 1 (that dimension
        is removed from the result).
    """
    return torch.sum(x * Phi, dim=1)
def At(y, Phi):
    """Replicate ``y`` along a new dim 1 and weight every copy by ``Phi``.

    ``y`` gets a singleton dimension inserted at position 1, which is then
    repeated ``Phi.shape[1]`` times before the element-wise multiplication.
    The ``repeat(1, n, 1, 1)`` call fixes the expanded tensor at four
    dimensions, i.e. ``y`` is expected to be 3-D.

    Args:
        y: Tensor to replicate.
        Phi: Tensor of weights; its size along dim 1 sets the repeat count.

    Returns:
        The replicated ``y`` multiplied element-wise by ``Phi``.
    """
    copies = Phi.shape[1]
    stacked = y.unsqueeze(1).repeat(1, copies, 1, 1)
    return stacked * Phi
def gen_log(model_path):
    """Configure the root logger to log to ``<model_path>/log.txt`` and the console.

    The root logger level is set to ``INFO``. A ``FileHandler`` (append
    mode) for ``<model_path>/log.txt`` and a ``StreamHandler`` are attached,
    both at ``INFO`` level with the format
    ``"%(asctime)s - %(levelname)s: %(message)s"``.

    The call is idempotent: handlers installed by an earlier ``gen_log``
    call are recognised and reused, so calling it repeatedly no longer
    duplicates every log line or leaks open file handles. A call with a
    different ``model_path`` adds a file handler for the new path only.

    Args:
        model_path: Existing directory in which ``log.txt`` is created or
            appended to.

    Returns:
        The configured root logger.

    Raises:
        OSError: If the log file cannot be opened (e.g. the directory does
            not exist).
    """
    import os  # local import: only needed to normalise the log path

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s - %(levelname)s: %(message)s")

    log_file = model_path + '/log.txt'
    # FileHandler stores its target as an absolute path in ``baseFilename``.
    target = os.path.abspath(log_file)

    # Only look at handlers this function installed; handlers attached by
    # other code (frameworks, test runners) must not suppress ours.
    owned = [h for h in logger.handlers if getattr(h, "_gen_log_owned", False)]

    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == target
        for h in owned
    )
    if not has_file_handler:
        fh = logging.FileHandler(log_file, mode='a')
        fh.setLevel(logging.INFO)
        fh.setFormatter(formatter)
        fh._gen_log_owned = True
        logger.addHandler(fh)

    # FileHandler subclasses StreamHandler, so exclude it explicitly.
    has_console_handler = any(
        not isinstance(h, logging.FileHandler) for h in owned
    )
    if not has_console_handler:
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        ch.setFormatter(formatter)
        ch._gen_log_owned = True
        logger.addHandler(ch)

    return logger