Skip to content

Commit

Permalink
Adjust deconvolution filter size for different scaling factors
Browse files Browse the repository at this point in the history
  • Loading branch information
igv committed Sep 1, 2017
1 parent f8af70e commit 5346bcf
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def __init__(self, sess, config):
# Different model layer counts and filter sizes for FSRCNN vs FSRCNN-s (fast), (d, s, m) in paper
model_params = [[56, 12, 4], [32, 8, 1]]
self.model_params = model_params[self.fast]

self.deconv_radius = [3, 5, 7][self.scale - 2]

self.checkpoint_dir = config.checkpoint_dir
self.output_dir = config.output_dir
Expand All @@ -65,11 +67,12 @@ def build_model(self):
d, s, m = self.model_params

expand_weight, deconv_weight = 'w{}'.format(m + 3), 'w{}'.format(m + 4)
deconv_size = self.deconv_radius * 2 + 1
self.weights = {
'w1': tf.Variable(tf.random_normal([5, 5, 1, d], stddev=0.0378, dtype=tf.float32), name='w1'),
'w2': tf.Variable(tf.random_normal([1, 1, d, s], stddev=0.3536, dtype=tf.float32), name='w2'),
expand_weight: tf.Variable(tf.random_normal([1, 1, s, d], stddev=0.189, dtype=tf.float32), name=expand_weight),
deconv_weight: tf.Variable(tf.random_normal([9, 9, 1, d], stddev=0.0001, dtype=tf.float32), name=deconv_weight)
deconv_weight: tf.Variable(tf.random_normal([deconv_size, deconv_size, 1, d], stddev=0.0001, dtype=tf.float32), name=deconv_weight)
}

expand_bias, deconv_bias = 'b{}'.format(m + 3), 'b{}'.format(m + 4)
Expand Down

0 comments on commit 5346bcf

Please sign in to comment.