Skip to content

Commit

Permalink
[Embedding] Check the sharded property of tf.train.Saver.
Browse files Browse the repository at this point in the history
Signed-off-by: chenbangduo.cbd <[email protected]>
  • Loading branch information
JackMoriarty committed May 21, 2024
1 parent 93c69ad commit 7e11183
Show file tree
Hide file tree
Showing 17 changed files with 27 additions and 32 deletions.
3 changes: 1 addition & 2 deletions modelzoo/bst/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,10 +612,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dbmtl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dcn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,10 +594,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dcnv2/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/deepfm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,10 +472,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dien/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,10 +776,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/din/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,10 +594,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dlrm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,10 +507,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dssm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/esmm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,10 +534,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/masknet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,10 +529,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/mlperf/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,10 +522,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/mmoe/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,10 +523,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/ple/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,10 +592,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/simple_multitask/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,10 +427,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/wide_and_deep/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,10 +543,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
11 changes: 11 additions & 0 deletions tensorflow/python/training/saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,10 +1071,14 @@ def _build(self, checkpoint_path, build_save, build_restore):
# pylint: disable=protected-access
self._var_list = variables._all_saveable_objects()
from tensorflow.python.ops import hash_table
from tensorflow.python.ops import kv_variable_ops
if isinstance(self._var_list, dict):
ev = {}
ht = {}
lst = {}
for name, x in self._var_list.items():
if isinstance(x, kv_variable_ops.EmbeddingVariable):
ev[name] = x
if isinstance(x, hash_table.HashTable):
if x.hash_table not in ht:
ht[x.hash_table] = [x]
Expand All @@ -1084,15 +1088,20 @@ def _build(self, checkpoint_path, build_save, build_restore):
lst[name] = BloomFilterSaveable(x)
else:
lst[name] = x
if len(ev) != 0 and not self._sharded:
raise ValueError("EmbeddingVariable can only use sharded saver")
if len(ht) != 0 and not self._sharded:
raise ValueError("HashTable can only use sharded saver")
for x, y in ht.items():
lst[x.name] = HashTableSaveable(y)
self._var_list = lst
else:
ev = []
ht = {}
lst = []
for x in self._var_list:
if isinstance(x, kv_variable_ops.EmbeddingVariable):
ev.append(x)
if isinstance(x, hash_table.HashTable):
if x.hash_table not in ht:
ht[x.hash_table] = [x]
Expand All @@ -1102,6 +1111,8 @@ def _build(self, checkpoint_path, build_save, build_restore):
lst.append(BloomFilterSaveable(x))
else:
lst.append(x)
if len(ev) != 0 and not self._sharded:
raise ValueError("EmbeddingVariable can only use sharded saver")
if len(ht) != 0 and not self._sharded:
raise ValueError("HashTable can only use sharded saver")
for x, y in ht.items():
Expand Down

0 comments on commit 7e11183

Please sign in to comment.