-
Notifications
You must be signed in to change notification settings - Fork 85
/
utils.py
37 lines (35 loc) · 1.67 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from torch import nn
from collections import OrderedDict
def _activation_for(layer_name, negative_slope):
    """Return the ``(name, module)`` activation implied by *layer_name*, or None.

    'leaky' is checked before 'relu' on purpose: 'relu' is a substring of
    names like 'conv1_leakyrelu_1' / 'conv1_leaky_relu_1', which must get a
    LeakyReLU, not a plain ReLU.
    """
    if 'leaky' in layer_name:
        return ('leaky_' + layer_name,
                nn.LeakyReLU(negative_slope=negative_slope, inplace=True))
    if 'relu' in layer_name:
        return ('relu_' + layer_name, nn.ReLU(inplace=True))
    return None


def make_layers(block, *, negative_slope=0.2):
    """Build an ``nn.Sequential`` from a mapping of layer names to specs.

    Layers are created in the iteration order of ``block`` and registered
    under their key, so ``block`` should be ordered (e.g. an OrderedDict).
    The layer type is chosen by substring match on the name:

    * contains 'pool'   -> ``nn.MaxPool2d``;
      spec is ``(kernel_size, stride, padding)``.
    * contains 'deconv' -> ``nn.ConvTranspose2d``;
      spec is ``(in_channels, out_channels, kernel_size, stride, padding)``.
    * contains 'conv'   -> ``nn.Conv2d``; same spec as 'deconv'.

    Conv / deconv layers are followed by an in-place activation when the
    name also contains 'leaky' (``nn.LeakyReLU``, registered as
    ``'leaky_' + name``) or 'relu' (``nn.ReLU``, registered as
    ``'relu_' + name``). Pool layers never get an activation.

    Args:
        block: Mapping of layer name -> indexable spec as described above.
        negative_slope: Slope used for every LeakyReLU (default 0.2).

    Returns:
        ``nn.Sequential`` whose submodules are named after the entries.

    Raises:
        NotImplementedError: if a name contains none of 'pool', 'deconv'
            or 'conv'.
    """
    layers = []
    for layer_name, v in block.items():
        if 'pool' in layer_name:
            layers.append((layer_name,
                           nn.MaxPool2d(kernel_size=v[0], stride=v[1],
                                        padding=v[2])))
            continue

        # 'deconv' must be tested before 'conv': 'conv' is a substring of it.
        if 'deconv' in layer_name:
            conv_cls = nn.ConvTranspose2d
        elif 'conv' in layer_name:
            conv_cls = nn.Conv2d
        else:
            raise NotImplementedError(
                f"unsupported layer {layer_name!r}: name must contain "
                f"'pool', 'deconv' or 'conv'")

        layers.append((layer_name,
                       conv_cls(in_channels=v[0],
                                out_channels=v[1],
                                kernel_size=v[2],
                                stride=v[3],
                                padding=v[4])))

        activation = _activation_for(layer_name, negative_slope)
        if activation is not None:
            layers.append(activation)
    return nn.Sequential(OrderedDict(layers))