Skip to content

Commit

Permalink
Fix some bugs reported in issue #8.
Browse files Browse the repository at this point in the history
  • Loading branch information
watsonyanghx committed Apr 23, 2018
1 parent 0de5931 commit f746546
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 20 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ CUDA_VISIBLE_DEVICES=0 python ./main.py --train_dir=../imgs/train/ \
--image_height=60 \
--image_width=180 \
--image_channel=1 \
--max_stepsize=64 \
--out_channels=64 \
--num_hidden=128 \
--batch_size=128 \
--log_dir=./log/train \
Expand Down
16 changes: 8 additions & 8 deletions cnn_lstm_otc_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, mode):
# SparseTensor required by ctc_loss op
self.labels = tf.sparse_placeholder(tf.int32)
# 1d array of size [batch_size]
self.seq_len = tf.placeholder(tf.int32, [None])
# self.seq_len = tf.placeholder(tf.int32, [None])
# l2
self._extra_train_ops = []

Expand All @@ -29,7 +29,7 @@ def build_graph(self):
self.merged_summay = tf.summary.merge_all()

def _build_model(self):
filters = [1, 64, 128, 128, FLAGS.max_stepsize]
filters = [1, 64, 128, 128, FLAGS.out_channels]
strides = [1, 2]

feature_h = FLAGS.image_height
Expand All @@ -54,15 +54,15 @@ def _build_model(self):

# print('----x.get_shape().as_list(): {}'.format(x.get_shape().as_list()))
_, feature_h, feature_w, _ = x.get_shape().as_list()
print('feature_h: {}, feature_w: {}'.format(feature_h, feature_w))
print('\nfeature_h: {}, feature_w: {}'.format(feature_h, feature_w))

# LSTM part
with tf.variable_scope('lstm'):
x = tf.reshape(x, [FLAGS.batch_size, -1, filters[4]]) # [batch_size, num_features, max_stepsize]
x = tf.transpose(x, [0, 2, 1]) # [batch_size, max_stepsize, num_features]
# shp = x.get_shape().as_list()
# x.set_shape([FLAGS.batch_size, filters[3], shp[1]])
x.set_shape([FLAGS.batch_size, filters[4], feature_h * feature_w])
x = tf.transpose(x, [0, 2, 1, 3]) # [batch_size, feature_w, feature_h, FLAGS.out_channels]
x = tf.reshape(x, [FLAGS.batch_size, feature_w, feature_h * FLAGS.out_channels])
print('lstm input shape: {}'.format(x.get_shape().as_list()))
self.seq_len = tf.fill([x.get_shape().as_list()[0]], feature_w)
# print('self.seq_len.shape: {}'.format(self.seq_len.shape.as_list()))

# tf.nn.rnn_cell.RNNCell, tf.nn.rnn_cell.GRUCell
cell = tf.nn.rnn_cell.LSTMCell(FLAGS.num_hidden, state_is_tuple=True)
Expand Down
15 changes: 6 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def train(train_dir=None, val_dir=None, mode='train'):

print('loading validation data')
val_feeder = utils.DataIterator(data_dir=val_dir)
print('size: ', val_feeder.size)
print('size: {}\n'.format(val_feeder.size))

num_train_samples = train_feeder.size # 100000
num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size) # example: 100000/100
Expand Down Expand Up @@ -69,12 +69,11 @@ def train(train_dir=None, val_dir=None, mode='train'):
batch_time = time.time()
indexs = [shuffle_idx[i % num_train_samples] for i in
range(cur_batch * FLAGS.batch_size, (cur_batch + 1) * FLAGS.batch_size)]
batch_inputs, batch_seq_len, batch_labels = \
batch_inputs, _, batch_labels = \
train_feeder.input_index_generate_batch(indexs)
# batch_inputs,batch_seq_len,batch_labels=utils.gen_batch(FLAGS.batch_size)
feed = {model.inputs: batch_inputs,
model.labels: batch_labels,
model.seq_len: batch_seq_len}
model.labels: batch_labels}

# if summary is needed
# batch_cost,step,train_summary,_ = sess.run([cost,global_step,merged_summay,optimizer],feed)
Expand Down Expand Up @@ -103,11 +102,10 @@ def train(train_dir=None, val_dir=None, mode='train'):
for j in range(num_batches_per_epoch_val):
indexs_val = [shuffle_idx_val[i % num_val_samples] for i in
range(j * FLAGS.batch_size, (j + 1) * FLAGS.batch_size)]
val_inputs, val_seq_len, val_labels = \
val_inputs, _, val_labels = \
val_feeder.input_index_generate_batch(indexs_val)
val_feed = {model.inputs: val_inputs,
model.labels: val_labels,
model.seq_len: val_seq_len}
model.labels: val_labels}

dense_decoded, lastbatch_err, lr = \
sess.run([model.dense_decoded, model.cost, model.lrn_rate],
Expand Down Expand Up @@ -177,8 +175,7 @@ def get_input_lens(seqs):
seq_len_input = np.asarray(seq_len_input)
seq_len_input = np.reshape(seq_len_input, [-1])

feed = {model.inputs: imgs_input,
model.seq_len: seq_len_input}
feed = {model.inputs: imgs_input}
dense_decoded_code = sess.run(model.dense_decoded, feed)

for item in dense_decoded_code:
Expand Down
4 changes: 2 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
tf.app.flags.DEFINE_integer('image_channel', 1, 'image channels as input')

tf.app.flags.DEFINE_integer('cnn_count', 4, 'count of cnn module to extract image features.')
tf.app.flags.DEFINE_integer('max_stepsize', 64,
tf.app.flags.DEFINE_integer('out_channels', 64,
'max stepsize in lstm, as well as the output channels of last layer in CNN')
tf.app.flags.DEFINE_integer('num_hidden', 128, 'number of hidden units in lstm')
tf.app.flags.DEFINE_float('output_keep_prob', 0.8, 'output_keep_prob in lstm')
Expand Down Expand Up @@ -101,7 +101,7 @@ def input_index_generate_batch(self, index=None):

def get_input_lens(sequences):
# every length is set to the number of output channels of the last CNN layer (64 by default)
lengths = np.asarray([FLAGS.max_stepsize for _ in sequences], dtype=np.int64)
lengths = np.asarray([FLAGS.out_channels for _ in sequences], dtype=np.int64)

return sequences, lengths

Expand Down

0 comments on commit f746546

Please sign in to comment.