Support for different radii for the deconvolution input tensor
Also append the radius to the checkpoint directory name.
igv committed Sep 10, 2017
1 parent 286e6cf commit e0cf791
Showing 4 changed files with 35 additions and 36 deletions.
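
In short: the fixed per-scale deconvolution radius table ([3, 5, 7]) becomes a configurable radius flag, and both the feature-extraction kernel and the deconvolution kernel are sized from it. A minimal standalone sketch of the two size formulas this commit introduces (names mirror the diff below):

    # Kernel-size math introduced by this commit (illustration only).
    def kernel_sizes(radius, scale):
        feature_size = radius * 2 + 1         # first conv kernel (w1)
        deconv_size = radius * scale * 2 + 1  # deconvolution kernel
        return feature_size, deconv_size

    # With the default radius of 1: scale 2 -> (3, 5), scale 3 -> (3, 7), scale 4 -> (3, 9).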
1 change: 1 addition & 0 deletions main.py
@@ -13,6 +13,7 @@
flags.DEFINE_float("learning_rate", 1e-4, "The learning rate of the adam optimizer [1e-4]")
flags.DEFINE_integer("c_dim", 1, "Dimension of image color [1]")
flags.DEFINE_integer("scale", 2, "The size of scale factor for preprocessing input image [2]")
flags.DEFINE_integer("radius", 1, "Max radius of the deconvolution input tensor [1]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")
flags.DEFINE_string("output_dir", "result", "Name of test output directory [result]")
flags.DEFINE_string("data_dir", "Train", "Name of data directory to train on [FastTrain]")
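The new flag is read like the existing ones; a minimal sketch of the round trip, using the TF 1.x flags API this file already uses (the command line shown is hypothetical):

    import tensorflow as tf

    flags = tf.app.flags
    flags.DEFINE_integer("radius", 1, "Max radius of the deconvolution input tensor [1]")
    FLAGS = flags.FLAGS

    # e.g. `python main.py --radius 2` yields FLAGS.radius == 2;
    # model.py below stores it as self.radius = config.radius.
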
15 changes: 8 additions & 7 deletions model.py
@@ -30,14 +30,16 @@ def __init__(self, sess, config):
self.is_grayscale = (self.c_dim == 1)
self.epoch = config.epoch
self.scale = config.scale
+self.radius = config.radius
self.batch_size = config.batch_size
self.learning_rate = config.learning_rate
self.threads = config.threads
self.distort = config.distort
self.params = config.params

+self.padding = 4
# Different image/label sub-sizes for different scaling factors x2, x3, x4
-scale_factors = [[24, 40], [18, 42], [16, 48]]
+scale_factors = [[20 + self.padding, 40], [14 + self.padding, 42], [12 + self.padding, 48]]
self.image_size, self.label_size = scale_factors[self.scale - 2]
# Testing uses different strides to ensure sub-images line up correctly
if not self.train:
@@ -48,8 +50,6 @@ def __init__(self, sess, config):
# Different model layer counts and filter sizes for FSRCNN vs FSRCNN-s (fast), (d, s, m) in paper
model_params = [[56, 12, 4], [32, 8, 1]]
self.model_params = model_params[self.fast]

-self.deconv_radius = [3, 5, 7][self.scale - 2]
-
self.checkpoint_dir = config.checkpoint_dir
self.output_dir = config.output_dir
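
The rewritten scale_factors table above makes the relation between the two sub-image sizes explicit: the input size is the downscaled label size plus the new 4-pixel padding. A quick check against the old hard-coded values:

    # image_size == label_size // scale + padding for every row of the table.
    padding = 4
    for scale, (image_size, label_size) in zip([2, 3, 4], [[24, 40], [18, 42], [16, 48]]):
        assert image_size == label_size // scale + padding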
@@ -169,7 +169,8 @@ def model(self):
d, s, m = self.model_params

# Feature Extraction
-self.weights['w1'] = tf.get_variable('w1', initializer=tf.random_normal([5, 5, 1, d], stddev=0.0378, dtype=tf.float32))
+size = self.radius * 2 + 1
+self.weights['w1'] = tf.get_variable('w1', initializer=tf.random_normal([size, size, 1, d], stddev=0.0378, dtype=tf.float32))
self.biases['b1'] = tf.get_variable('b1', initializer=tf.zeros([d]))
conv = self.prelu(tf.nn.conv2d(self.images, self.weights['w1'], strides=[1,1,1,1], padding='VALID') + self.biases['b1'], 1)

@@ -196,7 +197,7 @@ def model(self):
conv = self.prelu(tf.nn.conv2d(conv, expand_weights, strides=[1,1,1,1], padding='SAME') + expand_biases, m + 3)

# Deconvolution
-deconv_size = self.deconv_radius * 2 + 1
+deconv_size = self.radius * self.scale * 2 + 1
deconv_weights = tf.get_variable('w{}'.format(m + 4), initializer=tf.random_normal([deconv_size, deconv_size, 1, d], stddev=0.0001, dtype=tf.float32))
deconv_biases = tf.get_variable('b{}'.format(m + 4), initializer=tf.zeros([1]))
self.weights['w{}'.format(m + 4)], self.biases['b{}'.format(m + 4)] = deconv_weights, deconv_biases
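
Note that the default behaviour changes here: the old code looked the radius up in a fixed per-scale table, while the new code multiplies the configured radius by the scale factor. A quick comparison of the resulting deconvolution kernel sizes:

    old_sizes = {s: [3, 5, 7][s - 2] * 2 + 1 for s in (2, 3, 4)}  # {2: 7, 3: 11, 4: 15}
    new_r1 = {s: 1 * s * 2 + 1 for s in (2, 3, 4)}                # radius=1: {2: 5, 3: 7, 4: 9}
    new_r2 = {s: 2 * s * 2 + 1 for s in (2, 3, 4)}                # radius=2: {2: 9, 3: 13, 4: 17}

Checkpoints trained before this commit therefore no longer match the default kernel shapes, which is presumably why the radius is appended to the checkpoint directory name below.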
@@ -220,7 +221,7 @@ def prelu(self, _x, i):
def save(self, checkpoint_dir, step):
model_name = "FSRCNN.model"
d, s, m = self.model_params
model_dir = "%s_%s_%s-%s-%s" % ("fsrcnn", self.label_size, d, s, m)
model_dir = "%s_%s_%s-%s-%s_%s" % ("fsrcnn", self.label_size, d, s, m, "r"+str(self.radius))
checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

if not os.path.exists(checkpoint_dir):
@@ -233,7 +234,7 @@ def save(self, checkpoint_dir, step):
def load(self, checkpoint_dir):
print(" [*] Reading checkpoints...")
d, s, m = self.model_params
model_dir = "%s_%s_%s-%s-%s" % ("fsrcnn", self.label_size, d, s, m)
model_dir = "%s_%s_%s-%s-%s_%s" % ("fsrcnn", self.label_size, d, s, m, "r"+str(self.radius))
checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
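With the radius appended, checkpoints trained with different radii land in separate directories instead of colliding. For example, with scale 2 (so label_size = 40), the default FSRCNN parameters d, s, m = 56, 12, 4, and radius 1:

    model_dir = "%s_%s_%s-%s-%s_%s" % ("fsrcnn", 40, 56, 12, 4, "r" + str(1))
    # -> 'fsrcnn_40_56-12-4_r1'
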
17 changes: 10 additions & 7 deletions sort.py
@@ -2,26 +2,29 @@

def main():
scale = 2
-radius = [3, 5, 7][scale-2]
+radius = 2
+size = radius * scale * 2 + 1
d = 64 #size of the feature layer

if len(sys.argv) == 2:
fname=sys.argv[1]
with open(fname) as f:
content = f.readlines()
content = [x.strip() for x in content]

+x=list(reversed(range(scale)))
+x=x[-1:]+x[:-1]
xy = []
-for i in range(0, scale):
-for j in range(0, scale):
+for i in x:
+for j in x:
xy.append([j, i])
xy = list(reversed(xy))

m = []
for i in range(0, len(xy)):
xi, yi = xy[i]
-for x in range(xi, radius*2+1, scale):
-for y in range(yi, radius*2+1, scale):
-m.append(y + x*(radius*2+1))
+for y in range(yi, size, scale):
+for x in range(xi, size, scale):
+m.append(y + x * size)
#print(m)
content = list(reversed(content))
sort = [content[m[l]].strip(",") for l in range(0, len(m))]
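For intuition: the nested loops walk the flattened size x size deconvolution kernel one sub-pixel phase at a time (one group of taps per offset within a scale x scale block). A small standalone check with scale = 2 and radius = 1 (so size = 5), confirming that the permutation touches every kernel entry exactly once:

    scale, radius = 2, 1
    size = radius * scale * 2 + 1  # 5
    x = list(reversed(range(scale)))
    x = x[-1:] + x[:-1]
    xy = list(reversed([[j, i] for i in x for j in x]))
    m = [y + xc * size
         for xi, yi in xy
         for y in range(yi, size, scale)
         for xc in range(xi, size, scale)]
    assert sorted(m) == list(range(size * size))  # all 25 entries, each once
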
38 changes: 16 additions & 22 deletions utils.py
@@ -120,11 +120,9 @@ def modcrop(image, scale=3):

def train_input_worker(args):
image_data, config = args
-image_size, label_size, stride, scale, distort = config
+image_size, label_size, stride, scale, padding, distort = config

single_input_sequence, single_label_sequence = [], []
-padding = abs(image_size - label_size) // 2 # eg. for 3x: (21 - 11) / 2 = 5
-label_padding = abs((image_size - 4) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7

input_, label_ = preprocess(image_data, scale, distort=distort)

@@ -133,10 +131,10 @@
else:
h, w = input_.shape

-for x in range(0, h - image_size - padding + 1, stride):
-for y in range(0, w - image_size - padding + 1, stride):
-sub_input = input_[x + padding : x + padding + image_size, y + padding : y + padding + image_size]
-x_loc, y_loc = x + label_padding, y + label_padding
+for x in range(0, h - image_size + 1, stride):
+for y in range(0, w - image_size + 1, stride):
+sub_input = input_[x : x + image_size, y : y + image_size]
+x_loc, y_loc = x + padding, y + padding
sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]

sub_input = sub_input.reshape([image_size, image_size, 1])
@@ -165,7 +163,7 @@ def thread_train_setup(config):
pool = Pool(config.threads)

# Distribute |images_per_thread| images across each worker process
-config_values = [config.image_size, config.label_size, config.stride, config.scale, config.distort]
+config_values = [config.image_size, config.label_size, config.stride, config.scale, config.padding // 2, config.distort]
images_per_thread = len(data) // config.threads
workers = []
for thread in range(config.threads):
@@ -202,14 +200,12 @@ def train_input_setup(config):
Read image files, make their sub-images, and save them as a h5 file format.
"""
sess = config.sess
-image_size, label_size, stride, scale = config.image_size, config.label_size, config.stride, config.scale
+image_size, label_size, stride, scale, padding = config.image_size, config.label_size, config.stride, config.scale, config.padding // 2

# Load data path
data = prepare_data(sess, dataset=config.data_dir)

sub_input_sequence, sub_label_sequence = [], []
-padding = abs(image_size - label_size) // 2 # eg. for 3x: (21 - 11) / 2 = 5
-label_padding = abs((image_size - 4) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7

for i in range(len(data)):
input_, label_ = preprocess(data[i], scale, distort=config.distort)
@@ -219,10 +215,10 @@
else:
h, w = input_.shape

-for x in range(0, h - image_size - padding + 1, stride):
-for y in range(0, w - image_size - padding + 1, stride):
-sub_input = input_[x + padding : x + padding + image_size, y + padding : y + padding + image_size]
-x_loc, y_loc = x + label_padding, y + label_padding
+for x in range(0, h - image_size + 1, stride):
+for y in range(0, w - image_size + 1, stride):
+sub_input = input_[x : x + image_size, y : y + image_size]
+x_loc, y_loc = x + padding, y + padding
sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]

sub_input = sub_input.reshape([image_size, image_size, 1])
@@ -242,14 +238,12 @@ def test_input_setup(config):
Read image files, make their sub-images, and save them as a h5 file format.
"""
sess = config.sess
-image_size, label_size, stride, scale = config.image_size, config.label_size, config.stride, config.scale
+image_size, label_size, stride, scale, padding = config.image_size, config.label_size, config.stride, config.scale, config.padding // 2

# Load data path
data = prepare_data(sess, dataset="Test")

sub_input_sequence, sub_label_sequence = [], []
-padding = abs(image_size - label_size) // 2 # eg. (21 - 11) / 2 = 5
-label_padding = abs((image_size - 4) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7

pic_index = 2 # Index of image based on lexicographic order in data folder
input_, label_ = preprocess(data[pic_index], config.scale)
@@ -260,13 +254,13 @@
h, w = input_.shape

nx, ny = 0, 0
-for x in range(0, h - image_size - padding + 1, stride):
+for x in range(0, h - image_size + 1, stride):
nx += 1
ny = 0
-for y in range(0, w - image_size - padding + 1, stride):
+for y in range(0, w - image_size + 1, stride):
ny += 1
-sub_input = input_[x + padding : x + padding + image_size, y + padding : y + padding + image_size]
-x_loc, y_loc = x + label_padding, y + label_padding
+sub_input = input_[x : x + image_size, y : y + image_size]
+x_loc, y_loc = x + padding, y + padding
sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]

sub_input = sub_input.reshape([image_size, image_size, 1])
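The net effect of the utils.py changes: the border arithmetic is computed once in model.py (self.padding = 4) instead of being re-derived in every loop, workers receive padding // 2 = 2, sub-inputs are cut without an offset, and only the label coordinates are shifted before scaling. A quick geometric check for scale 2 (image_size = 24, label_size = 40 from the model.py table above):

    scale, image_size, label_size = 2, 24, 40
    padding = 4 // 2                      # what the workers now receive
    x = 0                                 # first sub-image
    input_rows = (x, x + image_size)      # (0, 24) in the low-res image
    x_loc = x + padding                   # 2
    label_rows = (x_loc * scale, x_loc * scale + label_size)  # (4, 44) in the high-res image
    # The label sits inside the upscaled input span (0, 48),
    # inset by scale * padding = 4 pixels on each side.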
