From d50799675642ef1ac75d8c68094edd57d14d4b65 Mon Sep 17 00:00:00 2001
From: "chenbangduo.cbd"
Date: Tue, 21 May 2024 20:31:10 +0800
Subject: [PATCH] [Embedding] Check the sharded property of tf.train.Saver.

Signed-off-by: chenbangduo.cbd
---
 modelzoo/bst/train.py                              |  3 +-
 modelzoo/dbmtl/train.py                            |  3 +-
 modelzoo/dcn/train.py                              |  3 +-
 modelzoo/dcnv2/train.py                            |  3 +-
 modelzoo/deepfm/train.py                           |  3 +-
 modelzoo/dien/train.py                             |  3 +-
 modelzoo/din/train.py                              |  3 +-
 modelzoo/dlrm/train.py                             |  3 +-
 modelzoo/dssm/train.py                             |  3 +-
 modelzoo/esmm/train.py                             |  3 +-
 modelzoo/masknet/train.py                          |  3 +-
 modelzoo/mlperf/train.py                           |  3 +-
 modelzoo/mmoe/train.py                             |  3 +-
 modelzoo/ple/train.py                              |  3 +-
 modelzoo/simple_multitask/train.py                 |  3 +-
 modelzoo/wide_and_deep/train.py                    |  3 +-
 .../feature_column/feature_column_v2_test.py       |  6 +-
 .../python/ops/embedding_variable_ops_test.py      | 58 +++++++++----------
 tensorflow/python/training/saver.py                | 11 ++++
 tensorflow/python/training/saver_test.py           |  6 ++
 20 files changed, 65 insertions(+), 64 deletions(-)

diff --git a/modelzoo/bst/train.py b/modelzoo/bst/train.py
index eeeb136678b..536ddbc6905 100644
--- a/modelzoo/bst/train.py
+++ b/modelzoo/bst/train.py
@@ -612,10 +612,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/dbmtl/train.py b/modelzoo/dbmtl/train.py
index c848cbc76b2..36f2685a175 100644
--- a/modelzoo/dbmtl/train.py
+++ b/modelzoo/dbmtl/train.py
@@ -527,10 +527,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/dcn/train.py b/modelzoo/dcn/train.py
index 44701e22d9f..5094a18bd85 100644
--- a/modelzoo/dcn/train.py
+++ b/modelzoo/dcn/train.py
@@ -594,10 +594,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/dcnv2/train.py b/modelzoo/dcnv2/train.py
index 5b572af0425..c1346ad6d7d 100644
--- a/modelzoo/dcnv2/train.py
+++ b/modelzoo/dcnv2/train.py
@@ -610,10 +610,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/deepfm/train.py b/modelzoo/deepfm/train.py
index 166bedec0d0..89b2b823a46 100644
--- a/modelzoo/deepfm/train.py
+++ b/modelzoo/deepfm/train.py
@@ -472,10 +472,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/dien/train.py b/modelzoo/dien/train.py
index 190695f6ce0..f43fd2f1e73 100644
--- a/modelzoo/dien/train.py
+++ b/modelzoo/dien/train.py
@@ -776,10 +776,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/din/train.py b/modelzoo/din/train.py
index 058583ce6fd..34621dee45e 100644
--- a/modelzoo/din/train.py
+++ b/modelzoo/din/train.py
@@ -594,10 +594,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/dlrm/train.py b/modelzoo/dlrm/train.py
index cc4c045c349..9dff32aca52 100644
--- a/modelzoo/dlrm/train.py
+++ b/modelzoo/dlrm/train.py
@@ -507,10 +507,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/dssm/train.py b/modelzoo/dssm/train.py
index db949aac5e8..9d2264d9ce9 100644
--- a/modelzoo/dssm/train.py
+++ b/modelzoo/dssm/train.py
@@ -478,10 +478,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/esmm/train.py b/modelzoo/esmm/train.py
index 073b08814d4..1916ed76c27 100755
--- a/modelzoo/esmm/train.py
+++ b/modelzoo/esmm/train.py
@@ -534,10 +534,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/masknet/train.py b/modelzoo/masknet/train.py
index bb96a467701..bb9eee0ec3f 100644
--- a/modelzoo/masknet/train.py
+++ b/modelzoo/masknet/train.py
@@ -529,10 +529,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/mlperf/train.py b/modelzoo/mlperf/train.py
index ce34fe5e55c..559e4fb6efc 100644
--- a/modelzoo/mlperf/train.py
+++ b/modelzoo/mlperf/train.py
@@ -522,10 +522,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/mmoe/train.py b/modelzoo/mmoe/train.py
index 694eb45da80..a3a6c9146d8 100644
--- a/modelzoo/mmoe/train.py
+++ b/modelzoo/mmoe/train.py
@@ -523,10 +523,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/ple/train.py b/modelzoo/ple/train.py
index b2d2f2057ec..33aa9a15e8e 100644
--- a/modelzoo/ple/train.py
+++ b/modelzoo/ple/train.py
@@ -592,10 +592,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/simple_multitask/train.py b/modelzoo/simple_multitask/train.py
index 4ef1874a521..6eb51f7d4e9 100644
--- a/modelzoo/simple_multitask/train.py
+++ b/modelzoo/simple_multitask/train.py
@@ -427,10 +427,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/modelzoo/wide_and_deep/train.py b/modelzoo/wide_and_deep/train.py
index 3024f58024e..2d1c964e593 100644
--- a/modelzoo/wide_and_deep/train.py
+++ b/modelzoo/wide_and_deep/train.py
@@ -543,10 +543,9 @@ def train(sess_config,
     hooks = []
     hooks.extend(input_hooks)
 
-    sharded_saver = tf_config != None
     scaffold = tf.train.Scaffold(
         local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
-        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
+        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
 
     stop_hook = tf.train.StopAtStepHook(last_step=steps)
     log_hook = tf.train.LoggingTensorHook(
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 7946aee1e1a..24f8a36daa4 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -7527,7 +7527,7 @@ def testEmbeddingVariableForL2FeatureEviction(self):
     opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables_lib.global_variables_initializer()
     with self.test_session() as sess:
       sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
@@ -7758,7 +7758,7 @@ def testEmbeddingVariableForSharedEmbeddingColumnsWithPartitionNum(self):
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
     init = variables_lib.global_variables_initializer()
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
 
   @test_util.run_deprecated_v1
   def testEmbeddingVariableForInt32ID(self):
@@ -7783,7 +7783,7 @@ def testEmbeddingVariableForInt32ID(self):
     opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables_lib.global_variables_initializer()
     with self.test_session() as sess:
       sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
diff --git a/tensorflow/python/ops/embedding_variable_ops_test.py b/tensorflow/python/ops/embedding_variable_ops_test.py
index dbf254d5f14..664d62e5abb 100644
--- a/tensorflow/python/ops/embedding_variable_ops_test.py
+++ b/tensorflow/python/ops/embedding_variable_ops_test.py
@@ -162,7 +162,7 @@ def _RecordFreqTestTemplate(self, optimizer):
     opt = self._CreateOptimizer(optimizer)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v, gs)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
    init = variables.global_variables_initializer()
 
     model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -194,7 +194,7 @@ def _RecordVersionTemplate(self, optimizer):
     opt = self._CreateOptimizer(optimizer)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v, gs)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
 
     model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -232,7 +232,7 @@ def testSaveVersionWithGlobalStepEviction(self):
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v, global_step=gs)
     init = variables.global_variables_initializer()
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     model_path = os.path.join(checkpoint_directory, "model.ckpt")
     with self.test_session() as sess:
       sess.run([init])
@@ -269,7 +269,7 @@ def testFeatureColumnRecordFreqWithPartition(self):
     opt = adagrad.AdagradOptimizer(0.1)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
 
     model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -313,7 +313,7 @@ def testFeatureColumnRecordFreqSGDWithPartition(self):
     opt = gradient_descent.GradientDescentOptimizer(0.1)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
 
     model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -450,7 +450,7 @@ def testEmbeddingVariableForLookupInt32(self):
     opt = adam.AdamOptimizer(0.01)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
     with self.test_session() as sess:
       sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
@@ -643,7 +643,7 @@ def testEmbeddingVariableForL2FeatureEvictionFromContribFeatureColumn(self):
     opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
     with self.test_session() as sess:
       sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
@@ -682,7 +682,7 @@ def testEmbeddingVariableForGlobalStepEviction(self):
     opt = adagrad.AdagradOptimizer(0.1)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v, global_step=gs)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
     with self.test_session() as sess:
       sess.run([init])
@@ -720,7 +720,7 @@ def testEmbeddingVariableForL2FeatureEviction(self):
     opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
     with self.test_session() as sess:
       sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
@@ -1534,7 +1534,7 @@ def testEmbeddingVariableForSaveFreq(self):
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
     init = variables.global_variables_initializer()
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     model_path = os.path.join(checkpoint_directory, "model.ckpt")
     with self.test_session() as sess:
       sess.run([init])
@@ -1567,7 +1567,7 @@ def testEmbeddingVariableForL2FeatureEvictionDRAM(self):
     opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
     with self.test_session() as sess:
       sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
@@ -1724,7 +1724,7 @@ def runTestAdagrad(self, var, g):
     opt = adagrad.AdagradOptimizer(0.1)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v, global_step=gs)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
 
     model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -1778,7 +1778,7 @@ def runTestAdagrad(self, var, g):
     opt = adagrad.AdagradOptimizer(0.1)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v, global_step=gs)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
 
     model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -1849,7 +1849,7 @@ def runTestAdagrad(self, var, g):
     opt = adagrad.AdagradOptimizer(0.1)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v, global_step=gs)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
 
     model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -1923,7 +1923,7 @@ def testEmbeddingVariableForRecordFreq(self):
     opt = adagrad.AdagradOptimizer(0.1)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v, gs)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
 
     model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -1963,7 +1963,7 @@ def testEmbeddingVariableForRecordFreqWithCounterFilter(self):
     opt = adagrad.AdagradOptimizer(0.1)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v, gs)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
 
     model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -2278,7 +2278,7 @@ def testEmbeddingVariableForContirbFeatureColumnWithPartitionNum(self):
     opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
 
   def testSaveV3(self):
     print("testSaveV3")
@@ -2295,7 +2295,7 @@ def testSaveV3(self):
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v, global_step=gs)
     init = variables.global_variables_initializer()
-    saver = saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     checkpoint_directory = self.get_temp_dir()
     model_path = os.path.join(checkpoint_directory, "model.ckpt")
     with self.test_session() as sess:
@@ -2326,7 +2326,7 @@ def testEmbeddingVariableForNotSaveUnfilterFeature(self):
     opt = adagrad.AdagradOptimizer(0.1)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v, gs)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
 
     model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -2359,7 +2359,7 @@ def testEmbeddingVariableForSaveUnfilterFeature(self):
     opt = adagrad.AdagradOptimizer(0.1)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v, gs)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
 
     model_path = os.path.join(checkpoint_directory, "model1.ckpt")
@@ -2390,7 +2390,7 @@ def testEmbeddingVariableForMultiTierInference(self):
     opt = adagrad.AdagradOptimizer(0.1)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v, gs)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
     with self.test_session() as sess:
       sess.run([init])
@@ -2412,7 +2412,7 @@ def testEmbeddingVariableForMultiTierInference(self):
 
     emb = embedding_ops.embedding_lookup(emb_var, ids)
     tires = kv_variable_ops.lookup_tier(emb_var, math_ops.cast([1,2,3,4], dtypes.int64))
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     graph = ops.get_default_graph()
     with self.test_session(graph = graph) as sess:
       saver.restore(sess, os.path.join(checkpoint_directory, "model.ckpt"))
@@ -2784,7 +2784,7 @@ def testSetInitializedWithoutRestore(self):
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
     init = variables.global_variables_initializer()
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     with self.test_session() as sess:
       result = sess.run(var._is_initialized_op)
       self.assertEqual(False, result)
@@ -2806,7 +2806,7 @@ def testSetInitializedWithRestore(self):
     opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
     with self.test_session(graph=g) as sess:
       sess.run([init])
@@ -2823,7 +2823,7 @@ def testSetInitializedWithRestore(self):
     opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
     with self.test_session(graph=g) as sess:
       result = sess.run(var._is_initialized_op)
@@ -2860,7 +2860,7 @@ def testCountsTensor(self):
     opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
     with self.test_session(graph=g) as sess:
       sess.run([init])
@@ -2893,7 +2893,7 @@ def testCountsWithSparseAndDenseTensor(self):
     opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
     with self.test_session(graph=g) as sess:
       sess.run([init])
@@ -2929,7 +2929,7 @@ def testCountsTensorWithGradientDescent(self):
     opt = gradient_descent.GradientDescentOptimizer(0.1)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
     with self.test_session(graph=g) as sess:
       sess.run([init])
@@ -2964,7 +2964,7 @@ def testCountsDenseAndSparseTensorWithGradientDescent(self):
     opt = gradient_descent.GradientDescentOptimizer(0.1)
     g_v = opt.compute_gradients(loss)
     train_op = opt.apply_gradients(g_v)
-    saver = saver_module.Saver()
+    saver = saver_module.Saver(sharded=True)
     init = variables.global_variables_initializer()
     with self.test_session(graph=g) as sess:
       sess.run([init])
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index acc9723c183..e70226f2968 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -1071,10 +1071,14 @@ def _build(self, checkpoint_path, build_save, build_restore):
       # pylint: disable=protected-access
       self._var_list = variables._all_saveable_objects()
     from tensorflow.python.ops import hash_table
+    from tensorflow.python.ops import kv_variable_ops
     if isinstance(self._var_list, dict):
+      ev = {}
       ht = {}
       lst = {}
       for name, x in self._var_list.items():
+        if isinstance(x, kv_variable_ops.EmbeddingVariable):
+          ev[name] = x
         if isinstance(x, hash_table.HashTable):
           if x.hash_table not in ht:
             ht[x.hash_table] = [x]
@@ -1084,15 +1088,20 @@ def _build(self, checkpoint_path, build_save, build_restore):
            lst[name] = BloomFilterSaveable(x)
         else:
           lst[name] = x
+      if len(ev) != 0 and not self._sharded:
+        raise ValueError("EmbeddingVariable can only use sharded saver")
       if len(ht) != 0 and not self._sharded:
         raise ValueError("HashTable can only use sharded saver")
       for x, y in ht.items():
         lst[x.name] = HashTableSaveable(y)
       self._var_list = lst
     else:
+      ev = []
       ht = {}
       lst = []
       for x in self._var_list:
+        if isinstance(x, kv_variable_ops.EmbeddingVariable):
+          ev.append(x)
         if isinstance(x, hash_table.HashTable):
           if x.hash_table not in ht:
             ht[x.hash_table] = [x]
@@ -1102,6 +1111,8 @@ def _build(self, checkpoint_path, build_save, build_restore):
            lst.append(BloomFilterSaveable(x))
         else:
           lst.append(x)
+      if len(ev) != 0 and not self._sharded:
+        raise ValueError("EmbeddingVariable can only use sharded saver")
       if len(ht) != 0 and not self._sharded:
         raise ValueError("HashTable can only use sharded saver")
       for x, y in ht.items():
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index b48f00d0c14..365ef85af1d 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -852,6 +852,12 @@ def _model():
     for orig, restored in zip(orig_vals, restored_vals):
       self.assertAllEqual(orig, restored)
 
+  def testEnableSaverShardedWhenUseEmbeddingVariable(self):
+    with ops_lib.Graph().as_default():
+      emb_var = \
+        variable_scope.get_embedding_variable(name="emb_var", embedding_dim=64)
+      with self.assertRaisesRegexp(ValueError, "EmbeddingVariable"):
+        saver_module.Saver([emb_var], sharded=False)
 
 class SaveRestoreShardedTest(test.TestCase):
 
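---
A minimal usage sketch of the contract this patch enforces. It is not part
of the diff: it assumes a DeepRec build where the public alias
tf.get_embedding_variable is available, and every name in it is
illustrative rather than taken from the patch.

    # Sketch only, assuming DeepRec's TF 1.15 fork.
    import tensorflow as tf

    with tf.Graph().as_default():
      # EmbeddingVariable is a dynamically growing key-value embedding store.
      emb_var = tf.get_embedding_variable(
          "emb_var",
          embedding_dim=64,
          initializer=tf.ones_initializer(tf.float32))

      # After this patch, constructing a non-sharded Saver over a var_list
      # that contains an EmbeddingVariable fails fast in Saver._build...
      try:
        tf.train.Saver([emb_var], sharded=False)
      except ValueError as e:
        print(e)  # "EmbeddingVariable can only use sharded saver"

      # ...while a sharded Saver is accepted.
      saver = tf.train.Saver([emb_var], sharded=True)

This mirrors the HashTable check that already exists in Saver._build, and it
is why the modelzoo scripts above now pass sharded=True unconditionally
instead of only when TF_CONFIG is set.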