From b5877eb7b9837244f0637c87c27e36b0267b0f65 Mon Sep 17 00:00:00 2001
From: Dominik
Date: Wed, 8 Feb 2017 23:58:54 -0500
Subject: [PATCH] changed dropout to placeholder, resolved #27

---
 models/sentiment.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/models/sentiment.py b/models/sentiment.py
index d695bee74..bc4e07512 100644
--- a/models/sentiment.py
+++ b/models/sentiment.py
@@ -20,6 +20,7 @@ def __init__(self, vocab_size, hidden_size, dropout,
                  num_layers, max_gradient_norm, max_seq_length,
                  learning_rate, lr_decay,batch_size, forward_only=False):
         self.num_classes =2
+        self.dropout = dropout
         self.vocab_size = vocab_size
         self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
         self.learning_rate_decay_op = self.learning_rate.assign(
@@ -45,9 +46,12 @@ def __init__(self, vocab_size, hidden_size, dropout,
         self.seq_lengths = tf.placeholder(tf.int32, shape=[None],
                                           name="early_stop")
 
-        self.dropout_keep_prob_embedding = tf.constant(self.dropout)
-        self.dropout_keep_prob_lstm_input = tf.constant(self.dropout)
-        self.dropout_keep_prob_lstm_output = tf.constant(self.dropout)
+        self.dropout_keep_prob_embedding = tf.placeholder(tf.float32,
+            name="dropout_keep_prob_embedding")
+        self.dropout_keep_prob_lstm_input = tf.placeholder(tf.float32,
+            name="dropout_keep_prob_lstm_input")
+        self.dropout_keep_prob_lstm_output = tf.placeholder(tf.float32,
+            name="dropout_keep_prob_lstm_output")
 
         with tf.variable_scope("embedding"), tf.device("/cpu:0"):
             W = tf.get_variable(
@@ -199,6 +203,10 @@ def step(self, session, inputs, targets, seq_lengths, forward_only=False):
         input_feed[self.seq_input.name] = inputs
         input_feed[self.target.name] = targets
         input_feed[self.seq_lengths.name] = seq_lengths
+        input_feed[self.dropout_keep_prob_embedding.name] = self.dropout
+        input_feed[self.dropout_keep_prob_lstm_input.name] = self.dropout
+        input_feed[self.dropout_keep_prob_lstm_output.name] = self.dropout
+
        if not forward_only:
            input_feed[self.str_summary_type.name] = "train"
            output_feed = [self.merged, self.mean_loss, self.update]
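For reference, here is a minimal TF 1.x-style sketch of the pattern this patch moves to: the keep probability enters the graph as a `tf.placeholder` and is chosen per `session.run` call (e.g. `self.dropout` during training, `1.0` at evaluation time), instead of being baked into the graph as a `tf.constant`. Everything in the sketch (`x`, `keep_prob`, `W`, `b`, `logits`, the batch shape) is illustrative and not taken from `models/sentiment.py`.

```python
import numpy as np
import tensorflow as tf

# Placeholder-controlled dropout: the keep probability is supplied at run
# time via feed_dict rather than frozen into the graph as a constant.
x = tf.placeholder(tf.float32, shape=[None, 128], name="x")
keep_prob = tf.placeholder(tf.float32, name="keep_prob")

W = tf.get_variable("W", shape=[128, 2])
b = tf.get_variable("b", shape=[2])

dropped = tf.nn.dropout(x, keep_prob)  # inverted dropout on the inputs
logits = tf.matmul(dropped, W) + b

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.rand(4, 128).astype(np.float32)
    # training-style run: keep 50% of the activations
    train_logits = sess.run(logits, feed_dict={x: batch, keep_prob: 0.5})
    # evaluation-style run: keep_prob = 1.0 makes the dropout op a no-op
    eval_logits = sess.run(logits, feed_dict={x: batch, keep_prob: 1.0})
```

With this wiring, turning dropout on or off is purely a matter of what `step()` puts into `input_feed`; no separate inference graph is needed.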