-
Notifications
You must be signed in to change notification settings - Fork 4
/
vae.py
127 lines (116 loc) · 4.7 KB
/
vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Interpreter sanity check: warn (without aborting) when Python is older
# than 3.4, the oldest version the course helper libraries were tested on.
import sys
if not sys.version_info >= (3, 4):
    print(
        'You are running an older version of Python!\n\n'
        'You should consider updating to Python 3.4.0 or '
        'higher as the libraries built for this course '
        'have only been tested in Python 3.4 and higher.\n'
    )
    print(
        'Try installing the Python 3.5 version of anaconda '
        'and then restart `jupyter notebook`:\n'
        'https://www.continuum.io/downloads\n\n'
    )
# Now get necessary libraries
try:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from skimage.transform import resize
from skimage import data
from scipy.misc import imresize
import IPython.display as ipyd
except ImportError:
print('You are missing some packages! ' \
'We will try installing them before continuing!')
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from skimage.transform import resize
from skimage import data
from scipy.misc import imresize
import IPython.display as ipyd
print('Done!')
# Import Tensorflow
try:
import tensorflow as tf
except ImportError:
print("You do not have tensorflow installed!")
print("Follow the instructions on the following link")
print("to install tensorflow before continuing:")
print("")
print("https://github.com/pkmital/CADL#installation-preliminaries")
# This cell includes the provided libraries from the zip file
# and a library for displaying images from ipython, which
# we will use to display the gif
try:
from libs import utils, gif, datasets, dataset_utils, vae, dft
except ImportError:
print("Make sure you have started notebook in the same directory" +
" as the provided zip file which includes the 'libs' folder" +
" and the file 'utils.py' inside of it. You will NOT be able"
" to complete this assignment unless you restart jupyter"
" notebook inside the directory created by extracting"
" the zip file or cloning the github repo.")
class ScanFile(object):
    """Recursively collect file paths (or sub-directory paths) under a directory.

    Args:
        directory: Root directory handed to ``os.walk``.
        prefix: If truthy, keep only files whose name starts with this string.
        postfix: If truthy, keep only files whose name ends with this string
            (default ``'.jpg'``). Matching is case-sensitive, so e.g.
            ``'.JPG'`` files are not picked up by the default.

    When both ``prefix`` and ``postfix`` are set, a file must satisfy both.
    Pass ``postfix=None`` (and ``prefix=None``) to collect every file.
    """

    def __init__(self, directory, prefix=None, postfix='.jpg'):
        self.directory = directory
        self.prefix = prefix
        self.postfix = postfix

    def _matches(self, filename):
        """Return True if *filename* passes the configured prefix/postfix filters."""
        if self.prefix and not filename.startswith(self.prefix):
            return False
        if self.postfix and not filename.endswith(self.postfix):
            return False
        return True

    def scan_files(self):
        """Return the paths of all matching files below ``self.directory``.

        Sub-directories and file names are visited in sorted order, so the
        result is deterministic. This matters for callers that line up two
        scans index-by-index.

        Returns:
            list of str: ``os.path.join(dirpath, filename)`` for every file
            whose name passes the prefix/postfix filters.
        """
        files_list = []
        for dirpath, dirnames, filenames in os.walk(self.directory):
            # os.walk (top-down) descends into `dirnames` in list order;
            # sorting it in place replaces filesystem order with a stable one.
            dirnames.sort()
            for filename in sorted(filenames):
                if self._matches(filename):
                    files_list.append(os.path.join(dirpath, filename))
        return files_list

    def scan_subdir(self):
        """Return ``self.directory`` followed by every directory below it.

        Traversal order is deterministic (sorted), consistent with
        ``scan_files``.
        """
        subdir_list = []
        for dirpath, dirnames, _filenames in os.walk(self.directory):
            dirnames.sort()
            subdir_list.append(dirpath)
        return subdir_list
# Collect the training file paths with ScanFile (default postfix '.jpg';
# only JPG inputs are supported). The image list and the object list are
# handed to train_vae side by side, so they must have the same length.
# NOTE(review): pairing appears to be purely positional; confirm that both
# folders contain correspondingly named files so index i is the same sample.
image_dir = '/home/yida/Documents/buildboat/slic_superpixel/data/annotated_img'
scan1 = ScanFile(image_dir)
files_img = scan1.scan_files()
object_dir = '/home/yida/Documents/buildboat/slic_superpixel/data/annotated_obj'
scan2 = ScanFile(object_dir)
files_obj = scan2.scan_files()
# Validate with explicit raises rather than `assert`: asserts are stripped
# when Python runs with -O, which would let an empty or mismatched dataset
# reach the training loop.
if not files_img:
    raise RuntimeError('No training images found under ' + image_dir)
if len(files_obj) != len(files_img):
    raise RuntimeError(
        'Found {} image files but {} object files; both folders must contain '
        'the same number of files'.format(len(files_img), len(files_obj)))
print('Files assertion passed, ', len(files_img), 'training files in total')
# Input size handed to train_vae; presumably (height, width, channels).
input_shape = [100, 100, 3]
# Train it! Change these parameters!
tf.reset_default_graph()
vae.train_vae(files_img,
              files_obj,
              input_shape,
              learning_rate=0.0001,
              batch_size=225,
              n_epochs=50,
              n_examples=10,
              crop_shape=[81, 81, 3],
              crop_factor=0.9,
              n_filters=[100, 100, 100, 100],
              n_hidden=256,
              n_code=50,
              convolutional=True,
              variational=True,
              filter_sizes=[3, 3, 3, 3],
              dropout=True,
              keep_prob=0.8,
              activation=tf.nn.relu,
              img_step=100,
              save_step=100,
              ckpt_name="vae.ckpt")