Skip to content

Commit

Permalink
[Embedding] Check the sharded property of tf.train.Saver.
Browse files Browse the repository at this point in the history
Signed-off-by: chenbangduo.cbd <[email protected]>
  • Loading branch information
JackMoriarty committed May 22, 2024
1 parent 93c69ad commit d507996
Show file tree
Hide file tree
Showing 20 changed files with 65 additions and 64 deletions.
3 changes: 1 addition & 2 deletions modelzoo/bst/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,10 +612,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dbmtl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dcn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,10 +594,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dcnv2/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/deepfm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,10 +472,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dien/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,10 +776,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/din/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,10 +594,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dlrm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,10 +507,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dssm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/esmm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,10 +534,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/masknet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,10 +529,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/mlperf/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,10 +522,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/mmoe/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,10 +523,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/ple/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,10 +592,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/simple_multitask/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,10 +427,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/wide_and_deep/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,10 +543,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/python/feature_column/feature_column_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7527,7 +7527,7 @@ def testEmbeddingVariableForL2FeatureEviction(self):
opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v)
saver = saver_module.Saver()
saver = saver_module.Saver(sharded=True)
init = variables_lib.global_variables_initializer()
with self.test_session() as sess:
sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
Expand Down Expand Up @@ -7758,7 +7758,7 @@ def testEmbeddingVariableForSharedEmbeddingColumnsWithPartitionNum(self):
g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v)
init = variables_lib.global_variables_initializer()
saver = saver_module.Saver()
saver = saver_module.Saver(sharded=True)

@test_util.run_deprecated_v1
def testEmbeddingVariableForInt32ID(self):
Expand All @@ -7783,7 +7783,7 @@ def testEmbeddingVariableForInt32ID(self):
opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v)
saver = saver_module.Saver()
saver = saver_module.Saver(sharded=True)
init = variables_lib.global_variables_initializer()
with self.test_session() as sess:
sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
Expand Down
Loading

0 comments on commit d507996

Please sign in to comment.