-
Notifications
You must be signed in to change notification settings - Fork 0
/
demonstration.py
97 lines (66 loc) · 2.19 KB
/
demonstration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from utils import Logger
import utils
import networks
import networks_improve
import os
ARCH_MENU = """Choose architecture:
[1] Vanilla GAN
[2] Improved GAN (one-sided label smoothing, minibatch discrimination layer)
[3] Unrolled GAN (not yet implemented)
[4] DCGAN
"""

DATA_MENU = """Choose data:
[1] Cats
[2] Cars (not yet implemented)
"""


def read_int(lo, hi=None, read=input):
    """Read integers from *read* until one lies in ``[lo, hi]``.

    ``hi=None`` means there is no upper bound. Non-integer or out-of-range
    answers print a hint and re-prompt. They do not crash with ValueError,
    and for menu indices they do not silently wrap around through negative
    list indexing.

    Args:
        lo: smallest accepted value (inclusive).
        hi: largest accepted value (inclusive), or None for unbounded.
        read: zero-argument callable returning the raw answer (defaults to
            ``input``; injectable for testing).

    Returns:
        The first accepted integer.
    """
    while True:
        raw = read()
        try:
            value = int(raw)
        except ValueError:
            value = None
        if value is not None and value >= lo and (hi is None or value <= hi):
            return value
        if hi is None:
            print("Please enter an integer >= {}:".format(lo))
        else:
            print("Please enter an integer between {} and {}:".format(lo, hi))


def checkpoint_names(arch, dataname, root="./data/models"):
    """Return the file names of saved models for *arch* under ``root/arch``.

    If a directory named *dataname* is found while walking ``root/arch``, its
    files are returned. Otherwise this returns the files of the last directory
    ``os.walk`` visited, which is what the script did before. Returns ``[]``
    when ``root/arch`` does not exist.

    NOTE(review): preferring the *dataname* directory assumes checkpoints are
    laid out as ``root/arch/dataname/...`` -- confirm against
    ``utils.Logger.load_G``.
    """
    last_seen = []
    for dirpath, _subdirs, fnames in os.walk(os.path.join(root, arch)):
        if os.path.basename(dirpath) == dataname:
            return fnames
        last_seen = fnames
    return last_seen


def latest_epoch(fnames):
    """Return the highest epoch among generator checkpoint names, or ``None``.

    A name counts as a generator checkpoint when it contains "G" and the text
    after its last "_" parses as an int (e.g. "G_epoch_12" -> 12). Any other
    name is skipped rather than raising.
    """
    epochs = []
    for name in fnames:
        if "G" not in name or "_" not in name:
            continue
        try:
            epochs.append(int(name[name.rindex("_") + 1:]))
        except ValueError:
            continue
    return max(epochs, default=None)


if __name__ == "__main__":
    logger = utils.Logger("demonstrate", "demonstrate")
    architectures = ['VGAN', 'GAN_improve', 'dummy', 'DCGAN']
    data = ['Cats', 'Cars']
    # Architectures this script can sample from. Each module is used for
    # GeneratorNet, noise and vectors_to_images.
    network_modules = {"VGAN": networks, "GAN_improve": networks_improve}

    print(ARCH_MENU)
    arch = architectures[read_int(1, len(architectures)) - 1]
    if arch not in network_modules:
        # 'dummy' (Unrolled GAN) and 'DCGAN' have no generator implementation
        # here yet; stop before G would be used undefined.
        # TODO: implement DCGAN sampling.
        raise SystemExit("Architecture {} is not implemented yet.".format(arch))
    net = network_modules[arch]

    print(DATA_MENU)
    dataname = data[read_int(1, len(data)) - 1]

    max_epoch = latest_epoch(checkpoint_names(arch, dataname))
    if max_epoch is None:
        raise SystemExit(
            "No generator checkpoints found in ./data/models/{}".format(arch))
    print("Choose epoch: [0 - " + str(max_epoch) + "]")
    epoch_num = read_int(0, max_epoch)
    G = net.GeneratorNet()
    logger.load_G(arch, dataname, epoch_num, G)

    rows = 5
    cols = 5
    num_samples = rows * cols  # one generated image per grid cell
    print("Enter number of iterations: ")
    num_iters = read_int(0)
    for itr in range(num_iters):
        print("Iteration number " + str(itr))
        # Inference only: no need to build an autograd graph.
        with torch.no_grad():
            fake_data = G(net.noise(num_samples)).detach()
        fake_imgs = net.vectors_to_images(fake_data)

        fig = utils.plt.figure(figsize=(5, 5))
        for i, img in enumerate(fake_imgs[:num_samples], start=1):
            ax = fig.add_subplot(rows, cols, i, xticks=[], yticks=[])
            if arch == "GAN_improve":
                ax.imshow(img.permute(1, 2, 0))
            else:
                # VGAN images are shown as a single grayscale channel.
                # TODO: adjust this for color
                ax.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
            ax.axis('off')
        utils.plt.show()