diff --git a/modelzoo/bst/train.py b/modelzoo/bst/train.py
index eeeb136678b..536ddbc6905 100644
--- a/modelzoo/bst/train.py
+++ b/modelzoo/bst/train.py
@@ -612,10 +612,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/dbmtl/train.py b/modelzoo/dbmtl/train.py
index c848cbc76b2..36f2685a175 100644
--- a/modelzoo/dbmtl/train.py
+++ b/modelzoo/dbmtl/train.py
@@ -527,10 +527,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/dcn/train.py b/modelzoo/dcn/train.py
index 44701e22d9f..5094a18bd85 100644
--- a/modelzoo/dcn/train.py
+++ b/modelzoo/dcn/train.py
@@ -594,10 +594,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/dcnv2/train.py b/modelzoo/dcnv2/train.py
index 5b572af0425..c1346ad6d7d 100644
--- a/modelzoo/dcnv2/train.py
+++ b/modelzoo/dcnv2/train.py
@@ -610,10 +610,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/deepfm/train.py b/modelzoo/deepfm/train.py
index 166bedec0d0..89b2b823a46 100644
--- a/modelzoo/deepfm/train.py
+++ b/modelzoo/deepfm/train.py
@@ -472,10 +472,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/dien/train.py b/modelzoo/dien/train.py
index 190695f6ce0..f43fd2f1e73 100644
--- a/modelzoo/dien/train.py
+++ b/modelzoo/dien/train.py
@@ -776,10 +776,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/din/train.py b/modelzoo/din/train.py
index 058583ce6fd..34621dee45e 100644
--- a/modelzoo/din/train.py
+++ b/modelzoo/din/train.py
@@ -594,10 +594,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/dlrm/train.py b/modelzoo/dlrm/train.py
index cc4c045c349..9dff32aca52 100644
--- a/modelzoo/dlrm/train.py
+++ b/modelzoo/dlrm/train.py
@@ -507,10 +507,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/dssm/train.py b/modelzoo/dssm/train.py
index db949aac5e8..9d2264d9ce9 100644
--- a/modelzoo/dssm/train.py
+++ b/modelzoo/dssm/train.py
@@ -478,10 +478,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/esmm/train.py b/modelzoo/esmm/train.py
index 073b08814d4..1916ed76c27 100755
--- a/modelzoo/esmm/train.py
+++ b/modelzoo/esmm/train.py
@@ -534,10 +534,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/masknet/train.py b/modelzoo/masknet/train.py
index bb96a467701..bb9eee0ec3f 100644
--- a/modelzoo/masknet/train.py
+++ b/modelzoo/masknet/train.py
@@ -529,10 +529,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/mlperf/train.py b/modelzoo/mlperf/train.py
index ce34fe5e55c..559e4fb6efc 100644
--- a/modelzoo/mlperf/train.py
+++ b/modelzoo/mlperf/train.py
@@ -522,10 +522,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/mmoe/train.py b/modelzoo/mmoe/train.py
index 694eb45da80..a3a6c9146d8 100644
--- a/modelzoo/mmoe/train.py
+++ b/modelzoo/mmoe/train.py
@@ -523,10 +523,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/ple/train.py b/modelzoo/ple/train.py
index b2d2f2057ec..33aa9a15e8e 100644
--- a/modelzoo/ple/train.py
+++ b/modelzoo/ple/train.py
@@ -592,10 +592,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/simple_multitask/train.py b/modelzoo/simple_multitask/train.py
index 4ef1874a521..6eb51f7d4e9 100644
--- a/modelzoo/simple_multitask/train.py
+++ b/modelzoo/simple_multitask/train.py
@@ -427,10 +427,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/wide_and_deep/train.py b/modelzoo/wide_and_deep/train.py
index 3024f58024e..2d1c964e593 100644
--- a/modelzoo/wide_and_deep/train.py
+++ b/modelzoo/wide_and_deep/train.py
@@ -543,10 +543,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 7946aee1e1a..24f8a36daa4 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -7527,7 +7527,7 @@ def testEmbeddingVariableForL2FeatureEviction(self):
       opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
       g_v = opt.compute_gradients(loss)
      train_op = opt.apply_gradients(g_v)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables_lib.global_variables_initializer()
       with self.test_session() as sess:
         sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
@@ -7758,7 +7758,7 @@ def testEmbeddingVariableForSharedEmbeddingColumnsWithPartitionNum(self):
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
       init = variables_lib.global_variables_initializer()
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
 
   @test_util.run_deprecated_v1
   def testEmbeddingVariableForInt32ID(self):
@@ -7783,7 +7783,7 @@ def testEmbeddingVariableForInt32ID(self):
       opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables_lib.global_variables_initializer()
       with self.test_session() as sess:
         sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
diff --git a/tensorflow/python/ops/embedding_variable_ops_gpu_test.py b/tensorflow/python/ops/embedding_variable_ops_gpu_test.py
index d47d94d0d99..90bdf357355 100644
--- a/tensorflow/python/ops/embedding_variable_ops_gpu_test.py
+++ b/tensorflow/python/ops/embedding_variable_ops_gpu_test.py
@@ -748,7 +748,7 @@ def testSaveV3(self):
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v, global_step=gs)
       init = variables.global_variables_initializer()
-      saver = saver = saver_module.Saver()
+      saver = saver = saver_module.Saver(sharded=True)
       checkpoint_directory = self.get_temp_dir()
       model_path = os.path.join(checkpoint_directory, "model.ckpt")
       with self.test_session() as sess:
@@ -816,7 +816,7 @@ def testEmbeddingVariableSaveAndRestoreOptimzierStatesForMultiTierWithHbm(self):
      opt = adagrad.AdagradOptimizer(0.1)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v, gs)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       graph = ops.get_default_graph()
       with self.test_session(graph = graph) as sess:
         saver.restore(sess, os.path.join(checkpoint_directory, "model.ckpt-12345"))
diff --git a/tensorflow/python/ops/embedding_variable_ops_test.py b/tensorflow/python/ops/embedding_variable_ops_test.py
index dbf254d5f14..1119fd1c194 100644
--- a/tensorflow/python/ops/embedding_variable_ops_test.py
+++ b/tensorflow/python/ops/embedding_variable_ops_test.py
@@ -162,7 +162,7 @@ def _RecordFreqTestTemplate(self, optimizer):
       opt = self._CreateOptimizer(optimizer)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v, gs)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
 
       model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -194,7 +194,7 @@ def _RecordVersionTemplate(self, optimizer):
       opt = self._CreateOptimizer(optimizer)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v, gs)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
 
       model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -232,7 +232,7 @@ def testSaveVersionWithGlobalStepEviction(self):
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v, global_step=gs)
       init = variables.global_variables_initializer()
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       model_path = os.path.join(checkpoint_directory, "model.ckpt")
       with self.test_session() as sess:
         sess.run([init])
@@ -269,7 +269,7 @@ def testFeatureColumnRecordFreqWithPartition(self):
       opt = adagrad.AdagradOptimizer(0.1)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
 
       model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -313,7 +313,7 @@ def testFeatureColumnRecordFreqSGDWithPartition(self):
       opt = gradient_descent.GradientDescentOptimizer(0.1)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
 
       model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -387,7 +387,8 @@ def testDynamicEmbeddingVariableForInitFromProto(self):
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
       graph = ops.get_default_graph()
-      meta_graph_def = saver_module.export_meta_graph()
+      saver = saver_module.Saver(sharded=True)
+      meta_graph_def = saver_module.export_meta_graph(saver_def=saver.as_saver_def())
       ops.reset_default_graph()
       with self.test_session() as sess:
         res = saver_module.import_meta_graph(meta_graph_def)
@@ -406,7 +407,8 @@ def testEmbeddingVariableForInitFromProto(self):
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
       graph = ops.get_default_graph()
-      meta_graph_def = saver_module.export_meta_graph()
+      saver = saver_module.Saver(sharded=True)
+      meta_graph_def = saver_module.export_meta_graph(saver_def=saver.as_saver_def())
       ops.reset_default_graph()
       with self.test_session() as sess:
         res = saver_module.import_meta_graph(meta_graph_def)
@@ -450,7 +452,7 @@ def testEmbeddingVariableForLookupInt32(self):
       opt = adam.AdamOptimizer(0.01)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
       with self.test_session() as sess:
         sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
@@ -643,7 +645,7 @@ def testEmbeddingVariableForL2FeatureEvictionFromContribFeatureColumn(self):
       opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
       with self.test_session() as sess:
         sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
@@ -682,7 +684,7 @@ def testEmbeddingVariableForGlobalStepEviction(self):
       opt = adagrad.AdagradOptimizer(0.1)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v, global_step=gs)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
       with self.test_session() as sess:
         sess.run([init])
@@ -720,7 +722,7 @@ def testEmbeddingVariableForL2FeatureEviction(self):
       opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
       with self.test_session() as sess:
         sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
@@ -1534,7 +1536,7 @@ def testEmbeddingVariableForSaveFreq(self):
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
       init = variables.global_variables_initializer()
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       model_path = os.path.join(checkpoint_directory, "model.ckpt")
       with self.test_session() as sess:
         sess.run([init])
@@ -1567,7 +1569,7 @@ def testEmbeddingVariableForL2FeatureEvictionDRAM(self):
       opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
       with self.test_session() as sess:
         sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
@@ -1724,7 +1726,7 @@ def runTestAdagrad(self, var, g):
       opt = adagrad.AdagradOptimizer(0.1)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v, global_step=gs)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
 
       model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -1778,7 +1780,7 @@ def runTestAdagrad(self, var, g):
       opt = adagrad.AdagradOptimizer(0.1)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v, global_step=gs)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
 
       model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -1849,7 +1851,7 @@ def runTestAdagrad(self, var, g):
       opt = adagrad.AdagradOptimizer(0.1)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v, global_step=gs)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
 
       model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -1923,7 +1925,7 @@ def testEmbeddingVariableForRecordFreq(self):
       opt = adagrad.AdagradOptimizer(0.1)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v, gs)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
 
       model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -1963,7 +1965,7 @@ def testEmbeddingVariableForRecordFreqWithCounterFilter(self):
       opt = adagrad.AdagradOptimizer(0.1)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v, gs)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
 
       model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -2278,7 +2280,7 @@ def testEmbeddingVariableForContirbFeatureColumnWithPartitionNum(self):
       opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
 
   def testSaveV3(self):
     print("testSaveV3")
@@ -2295,7 +2297,7 @@ def testSaveV3(self):
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v, global_step=gs)
       init = variables.global_variables_initializer()
-      saver = saver = saver_module.Saver()
+      saver = saver = saver_module.Saver(sharded=True)
       checkpoint_directory = self.get_temp_dir()
       model_path = os.path.join(checkpoint_directory, "model.ckpt")
       with self.test_session() as sess:
@@ -2326,7 +2328,7 @@ def testEmbeddingVariableForNotSaveUnfilterFeature(self):
       opt = adagrad.AdagradOptimizer(0.1)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v, gs)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
 
       model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -2359,7 +2361,7 @@ def testEmbeddingVariableForSaveUnfilterFeature(self):
       opt = adagrad.AdagradOptimizer(0.1)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v, gs)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
 
       model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -2390,7 +2392,7 @@ def testEmbeddingVariableForMultiTierInference(self):
       opt = adagrad.AdagradOptimizer(0.1)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v, gs)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
       with self.test_session() as sess:
         sess.run([init])
@@ -2412,7 +2414,7 @@ def testEmbeddingVariableForMultiTierInference(self):
       emb = embedding_ops.embedding_lookup(emb_var, ids)
       tires = kv_variable_ops.lookup_tier(emb_var, math_ops.cast([1,2,3,4], dtypes.int64))
 
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       graph = ops.get_default_graph()
       with self.test_session(graph = graph) as sess:
         saver.restore(sess, os.path.join(checkpoint_directory, "model.ckpt"))
@@ -2784,7 +2786,7 @@ def testSetInitializedWithoutRestore(self):
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
       init = variables.global_variables_initializer()
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       with self.test_session() as sess:
         result = sess.run(var._is_initialized_op)
         self.assertEqual(False, result)
@@ -2806,7 +2808,7 @@ def testSetInitializedWithRestore(self):
       opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
       with self.test_session(graph=g) as sess:
         sess.run([init])
@@ -2823,7 +2825,7 @@ def testSetInitializedWithRestore(self):
       opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
       with self.test_session(graph=g) as sess:
         result = sess.run(var._is_initialized_op)
@@ -2860,7 +2862,7 @@ def testCountsTensor(self):
       opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
       with self.test_session(graph=g) as sess:
         sess.run([init])
@@ -2893,7 +2895,7 @@ def testCountsWithSparseAndDenseTensor(self):
       opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
       with self.test_session(graph=g) as sess:
         sess.run([init])
@@ -2929,7 +2931,7 @@ def testCountsTensorWithGradientDescent(self):
       opt = gradient_descent.GradientDescentOptimizer(0.1)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
       with self.test_session(graph=g) as sess:
         sess.run([init])
@@ -2964,7 +2966,7 @@ def testCountsDenseAndSparseTensorWithGradientDescent(self):
       opt = gradient_descent.GradientDescentOptimizer(0.1)
       g_v = opt.compute_gradients(loss)
       train_op = opt.apply_gradients(g_v)
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
       with self.test_session(graph=g) as sess:
         sess.run([init])
diff --git a/tensorflow/python/training/incr_ckpt_test.py b/tensorflow/python/training/incr_ckpt_test.py
index 55cf748a9d6..849c73a44dc 100644
--- a/tensorflow/python/training/incr_ckpt_test.py
+++ b/tensorflow/python/training/incr_ckpt_test.py
@@ -75,7 +75,7 @@ def testSparseEvIncrSaveRestore(self):
       emb = embedding_ops.embedding_lookup(var, math_ops.cast([0,1,2,5,6,7], dtypes.int64))
       with ops.device("/device:CPU:0"):
         apply_incr = gen_io_ops.record_sparse_indices(math_ops.cast([0,1,2,5,6,7], dtypes.int64), "var_ev1")
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
       ev_var_name = "var_ev1"
       incr_save_op = gen_io_ops.incr_save(incr_ckpt_path, [ev_var_name], [], [True],[var.handle])
@@ -178,7 +178,7 @@ def testMixIncrSaveRestore(self):
       activate_op = gen_io_ops. activate_sparse_recorder(["var_ev1","var_norm1"])
 
-      saver = saver_module.Saver()
+      saver = saver_module.Saver(sharded=True)
       init = variables.global_variables_initializer()
       incr_save_op = gen_io_ops.incr_save(incr_ckpt_path, ["var_norm1", "var_ev1"], [], [True, True], [var_norm, var_ev.handle])
@@ -445,6 +445,7 @@ def testIncrementalSaverForResourceVariable(self):
       variable_scope.get_variable('var', shape=[100], use_resource=False)
       variable_scope.get_embedding_variable('ev', embedding_dim=100)
       saver = saver_module.Saver(
+          sharded=True,
           save_relative_paths=True,
           incremental_save_restore=True,
       )
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index acc9723c183..e70226f2968 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -1071,10 +1071,14 @@ def _build(self, checkpoint_path, build_save, build_restore):
       # pylint: disable=protected-access
       self._var_list = variables._all_saveable_objects()
     from tensorflow.python.ops import hash_table
+    from tensorflow.python.ops import kv_variable_ops
     if isinstance(self._var_list, dict):
+      ev = {}
       ht = {}
       lst = {}
       for name, x in self._var_list.items():
+        if isinstance(x, kv_variable_ops.EmbeddingVariable):
+          ev[name] = x
         if isinstance(x, hash_table.HashTable):
           if x.hash_table not in ht:
             ht[x.hash_table] = [x]
@@ -1084,15 +1088,20 @@ def _build(self, checkpoint_path, build_save, build_restore):
           lst[name] = BloomFilterSaveable(x)
         else:
           lst[name] = x
+      if len(ev) != 0 and not self._sharded:
+        raise ValueError("EmbeddingVariable can only use sharded saver")
       if len(ht) != 0 and not self._sharded:
         raise ValueError("HashTable can only use sharded saver")
       for x, y in ht.items():
         lst[x.name] = HashTableSaveable(y)
       self._var_list = lst
     else:
+      ev = []
       ht = {}
       lst = []
       for x in self._var_list:
+        if isinstance(x, kv_variable_ops.EmbeddingVariable):
+          ev.append(x)
         if isinstance(x, hash_table.HashTable):
           if x.hash_table not in ht:
             ht[x.hash_table] = [x]
@@ -1102,6 +1111,8 @@ def _build(self, checkpoint_path, build_save, build_restore):
           lst.append(BloomFilterSaveable(x))
         else:
           lst.append(x)
+      if len(ev) != 0 and not self._sharded:
+        raise ValueError("EmbeddingVariable can only use sharded saver")
       if len(ht) != 0 and not self._sharded:
         raise ValueError("HashTable can only use sharded saver")
       for x, y in ht.items():
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index b48f00d0c14..365ef85af1d 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -852,6 +852,12 @@ def _model():
       for orig, restored in zip(orig_vals, restored_vals):
         self.assertAllEqual(orig, restored)
 
+  def testEnableSaverShardedWhenUseEmbeddingVariable(self):
+    with ops_lib.Graph().as_default():
+      emb_var = \
+        variable_scope.get_embedding_variable(name="emb_var", embedding_dim=64)
+      with self.assertRaisesRegexp(ValueError, "EmbeddingVariable"):
+        saver_module.Saver([emb_var], sharded=False)
 
 class SaveRestoreShardedTest(test.TestCase):