forked from juefeix/pnn.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 6
/
utils.py
114 lines (98 loc) · 4.53 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import copy
import numpy as np
import math
import torch
import torch.nn as nn
def readtextfile(filename):
    """Read a text file and return its lines.

    Args:
        filename: path of the file to open in text mode.

    Returns:
        A list with one string per line, in file order; each string keeps
        its line terminator, exactly as ``readlines`` yields it.
    """
    # The ``with`` statement closes the file on exit, so no explicit
    # close() call is needed.
    with open(filename) as f:
        return f.readlines()
def writetextfile(data, filename):
    """Write a sequence of strings to a text file, replacing its contents.

    Args:
        data: iterable of strings; they are written back to back, so each
            item must carry its own newline if one is wanted.
        filename: path of the file to create or truncate.
    """
    # The ``with`` statement flushes and closes the file on exit, so no
    # explicit close() call is needed.
    with open(filename, 'w') as f:
        f.writelines(data)
def delete_file(filename):
    """Remove ``filename`` if it is an existing regular file.

    Paths that do not exist, and paths that are not regular files (for
    example directories), are left alone and no error is raised for them.

    Args:
        filename: path of the file to remove.
    """
    # ``isfile`` already returns a bool; comparing it with ``True`` is noise.
    if os.path.isfile(filename):
        os.remove(filename)
def eformat(f, prec, exp_digits):
    """Format ``f`` in scientific notation with a fixed-width exponent.

    Args:
        f: number to format.
        prec: number of digits after the decimal point in the mantissa.
        exp_digits: number of digits the exponent is zero-padded to
            (the sign is always written in addition to these digits).

    Returns:
        A string such as ``'1.23e+003'``.
    """
    scientific = "%.*e" % (prec, f)
    mantissa, exponent = scientific.split('e')
    # The field width counts the mandatory +/- sign, hence one extra column.
    field_width = exp_digits + 1
    padded_exponent = "%+0*d" % (field_width, int(exponent))
    return "%se%s" % (mantissa, padded_exponent)
def saveargs(args):
    """Dump every attribute of ``args`` to ``<args.logs>/args.txt``.

    The directory named by ``args.logs`` is created (including parents)
    when it does not exist yet. The file gets one line per attribute, in
    the order ``vars(args)`` yields them, formatted as ``name value``.

    Args:
        args: an object exposing its settings as instance attributes (for
            example an ``argparse.Namespace``); it must have a ``logs``
            attribute holding the output directory.
    """
    path = args.logs
    # ``isdir`` already returns a bool; no need to compare it with ``False``.
    if not os.path.isdir(path):
        os.makedirs(path)
    with open(os.path.join(path, 'args.txt'), 'w') as f:
        for arg in vars(args):
            f.write(arg + ' ' + str(getattr(args, arg)) + '\n')
def init_params(net):
    """Re-initialise the Conv2d, BatchNorm2d and Linear layers of ``net`` in place.

    * ``nn.Conv2d``: Kaiming-normal weights (``mode='fan_out'``), zero bias.
    * ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
    * ``nn.Linear``: normal weights with ``std=1e-3``, zero bias.

    Layers created without a bias (or BatchNorm layers without affine
    parameters) have those parameters set to ``None``; they are skipped.

    Args:
        net: the ``nn.Module`` whose sub-modules are visited via ``modules()``.
    """
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            # Test against None: truth-testing a multi-element tensor raises
            # RuntimeError, and a bias-free layer stores None here.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # With affine=False both parameters are None.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
def weights_init(m):
    """Initialise a single module in place; intended for per-module use.

    * ``nn.Conv2d``: weights drawn from a zero-mean normal whose standard
      deviation is ``sqrt(2 / (kernel_h * kernel_w * out_channels))``.
    * ``nn.BatchNorm2d``: weight filled with 1, bias zeroed.

    Any other module type is left untouched.
    """
    if isinstance(m, nn.Conv2d):
        kernel_h, kernel_w = m.kernel_size[0], m.kernel_size[1]
        fan_out = kernel_h * kernel_w * m.out_channels
        std = math.sqrt(2. / fan_out)
        m.weight.data.normal_(0, std)
        return

    if isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
class Counter:  # not used currently
    """Accumulator that keeps a running total of the sizes passed to it."""

    def __init__(self):
        # Nothing has been accumulated yet.
        self.mask_size = 0

    def update(self, size):
        """Add ``size`` to the running total."""
        self.mask_size += size

    def get_total(self):
        """Return the total accumulated so far."""
        total = self.mask_size
        return total
def act_fn(act):
    """Build the activation module selected by ``act``.

    Recognised names: ``relu``, ``lrelu``, ``prelu``, ``rrelu``, ``elu``,
    ``selu``, ``tanh`` and ``sigmoid``. Apart from ``relu`` (created with
    ``inplace=False``), the ReLU-family variants that accept the flag are
    created with ``inplace=True``.

    Returns:
        A freshly constructed ``nn.Module``, or ``None`` after printing a
        message when ``act`` is not one of the recognised names.
    """
    # Factories are lazy so only the requested module is ever constructed.
    factories = (
        ('relu', lambda: nn.ReLU(inplace=False)),
        ('lrelu', lambda: nn.LeakyReLU(inplace=True)),
        ('prelu', lambda: nn.PReLU()),
        ('rrelu', lambda: nn.RReLU(inplace=True)),
        ('elu', lambda: nn.ELU(inplace=True)),
        ('selu', lambda: nn.SELU(inplace=True)),
        ('tanh', lambda: nn.Tanh()),
        ('sigmoid', lambda: nn.Sigmoid()),
    )
    for name, build in factories:
        if act == name:
            return build()

    print('\n\nActivation function {} is not supported/understood\n\n'.format(act))
    return None
def print_values(x, noise, y, unique_masks, n=2):
    """Print the first ``n`` values of selected rows of ``x``, ``noise`` and ``y``.

    Debugging aid. It also changes NumPy's global print options
    (precision 5, line width 200, no scientific notation).

    Args:
        x: input tensor; a singleton axis is inserted at position 2 and the
            result is indexed with five indices for images 0-1, channels 0-1.
        noise: tensor indexed with five indices, or ``None`` to skip the
            noise section.
        y: tensor indexed with five indices for images 0-1, channels 0-1,
            masks 0-1.
        unique_masks: when truthy (and ``noise`` is given) the channel-1
            noise rows are printed as well.
        n: how many leading values of each selected row to show.
    """
    np.set_printoptions(precision=5, linewidth=200, threshold=1000000, suppress=True)

    def first_values(tensor, index):
        # Same element selection as tensor.data[i, j, k, l, :n].
        return tensor.data[index + (slice(None, n),)].cpu().numpy()

    def emit(rows, tensor):
        # Each row is (message template, four leading indices).
        for template, index in rows:
            print(template.format(list(tensor.size()), first_values(tensor, index)))

    emit(
        (
            ('\nimage: {} image0, channel0 {}', (0, 0, 0, 0)),
            ('image: {} image0, channel1 {}', (0, 1, 0, 0)),
            ('\nimage: {} image1, channel0 {}', (1, 0, 0, 0)),
            ('image: {} image1, channel1 {}', (1, 1, 0, 0)),
        ),
        x.unsqueeze(2),
    )

    if noise is not None:
        emit(
            (
                ('\nnoise {} channel0, mask0: {}', (0, 0, 0, 0)),
                ('noise {} channel0, mask1: {}', (0, 0, 1, 0)),
            ),
            noise,
        )
        if unique_masks:
            emit(
                (
                    ('\nnoise {} channel1, mask0: {}', (0, 1, 0, 0)),
                    ('noise {} channel1, mask1: {}', (0, 1, 1, 0)),
                ),
                noise,
            )

    emit(
        (
            ('\nmasks: {} image0, channel0, mask0: {}', (0, 0, 0, 0)),
            ('masks: {} image0, channel0, mask1: {}', (0, 0, 1, 0)),
            ('masks: {} image0, channel1, mask0: {}', (0, 1, 0, 0)),
            ('masks: {} image0, channel1, mask1: {}', (0, 1, 1, 0)),
            ('\nmasks: {} image1, channel0, mask0: {}', (1, 0, 0, 0)),
            ('masks: {} image1, channel0, mask1: {}', (1, 0, 1, 0)),
            ('masks: {} image1, channel1, mask0: {}', (1, 1, 0, 0)),
            ('masks: {} image1, channel1, mask1: {}', (1, 1, 1, 0)),
        ),
        y,
    )