Skip to content

Commit

Permalink
implemented saver using tensorflow
Browse files Browse the repository at this point in the history
  • Loading branch information
LordSomen committed Aug 15, 2018
1 parent c6c1577 commit 203d615
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion Tensorflow/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,21 @@ def fetch_batch(epoch, batch_index, batch_size):
for batch_index in range(n_batches):
X_batch, y_batch = fetch_batch(epoch, batch_index, batch_size)
sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
best_theta = theta.eval()
best_theta = theta.eval()

#%%
theta = tf.Variable(tf.random_uniform([n + 1, 1], -1.0, 1.0), name="theta")
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
for epoch in range(n_epochs):
if epoch % 100 == 0: # checkpoint every 100 epochs
save_path = saver.save(sess, "/tmp/my_model.ckpt")
sess.run(training_op)
best_theta = theta.eval()
save_path = saver.save(sess, "/tmp/my_model_final.ckpt")

#%%
with tf.Session() as sess:
saver.restore(sess, "/tmp/my_model_final.ckpt")

0 comments on commit 203d615

Please sign in to comment.