Skip to content

Commit

Permalink
[Embedding] Add dependencies of restore op.
Browse files Browse the repository at this point in the history
Signed-off-by: lixy9474 <[email protected]>
  • Loading branch information
lixy9474 committed Sep 20, 2023
1 parent 62d7e4b commit 8c1ed0b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
2 changes: 2 additions & 0 deletions tensorflow/python/framework/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6132,6 +6132,8 @@ class GraphKeys(object):
TRAINABLE_VARIABLES = "trainable_variables"
# Indicate EmbeddingVariable in CollectionDef
EMBEDDING_VARIABLES = "embedding_variables"
# Collection for dependencies of EmbeddingVariable's restore op
EMBEDDING_VARIABLE_RESTORE_DEPENDENCY = "embedding_variable_restore_dependency"
# Key to collect summaries.
SUMMARIES = "summaries"
# Key to collect QueueRunners.
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/python/ops/kv_variable_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@ def _init_from_args(self,
self._slot_num = 0
else:
self._slot_num = evconfig.slot_num
if self._is_primary:
self._import_dependency_ops = []
with ops.name_scope("IsInitialized"):
self._is_initialized_op = (
gen_kv_variable_ops.kv_var_is_initialized_op(self._handle,
Expand Down Expand Up @@ -488,6 +490,7 @@ def create_init_op_for_restore(self, name, initial_value, invalid_key, rank):
set_attr_ops.append(set_cache_op)
with ops.control_dependencies(set_attr_ops + [self._initializer_for_restore]):
self._init_op_for_restore = control_flow_ops.no_op()
self.collect_restore_denpendencies()

def need_counts(self):
  """Report whether any of the count-related settings is active.

  Evaluates, in order and stopping at the first truthy result:
  `self._record_freq`, `self._filter_freq > 0`, `self._is_multi_tier`.

  Returns:
    The first truthy value among those three expressions, or the value of
    `self._is_multi_tier` when none of the earlier ones is truthy.
  """
  # Check order matters: later attributes are only read when the earlier
  # checks are falsy, exactly like a chained `or`.
  record_freq = self._record_freq
  if record_freq:
    return record_freq
  filter_enabled = self._filter_freq > 0
  if filter_enabled:
    return filter_enabled
  return self._is_multi_tier
Expand Down Expand Up @@ -612,8 +615,19 @@ def _init_from_proto(self, variable_def, import_scope=None):
else:
self._is_primary = False

self.collect_restore_denpendencies()
# LINT.ThenChange(//tensorflow/python/eager/graph_callable.py)

def collect_restore_denpendencies(self):
  """Record this variable's restore-init op under its primary handle.

  The graph collection `GraphKeys.EMBEDDING_VARIABLE_RESTORE_DEPENDENCY`
  stores a single dict as its first element. That dict maps a primary handle
  to the list of `_init_op_for_restore` ops registered for it. The dict is
  created and added to the collection the first time this method runs for a
  graph; later calls append to the existing entry.

  Reads `self._primary_handle` and `self._init_op_for_restore`, so both must
  already be set when this is called.
  """
  collection_key = ops.GraphKeys.EMBEDDING_VARIABLE_RESTORE_DEPENDENCY
  restore_dependency = ops.get_collection(collection_key)
  if restore_dependency:
    dependency_dict = restore_dependency[0]
  else:
    # First registration in this graph: create the shared dict and keep a
    # reference to it, so the collection does not need to be fetched again.
    dependency_dict = {}
    ops.add_to_collection(collection_key, dependency_dict)
  # Group init ops by primary handle; start a new list on first sight.
  dependency_dict.setdefault(self._primary_handle, []).append(
      self._init_op_for_restore)

def set_init_data_source_initializer(self, init_data_source):
import pkgutil
try:
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/python/training/saving/saveable_object_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ def restore(self, restored_tensors, unused_restored_shapes):
if self.var._init_data_source is not None:
return self.var.recover_from_init_data_source(self.var._init_data_source, self.partition_id, self.partition_num)
else:
with ops.control_dependencies([self.var._init_op_for_restore]):
restore_dependency = ops.get_collection(ops.GraphKeys.EMBEDDING_VARIABLE_RESTORE_DEPENDENCY)[0]
with ops.control_dependencies(restore_dependency[self.var._primary_handle]):
rank = self.op.initial_value.get_shape().rank - 1
restore_op = gen_kv_variable_ops.kv_resource_import_v3(
restored_tensors[0],
Expand Down

0 comments on commit 8c1ed0b

Please sign in to comment.