[Embedding] Fix set initialized flag too early in restore subgraph. #920

Merged: merged 1 commit on Aug 11, 2023
4 changes: 4 additions & 0 deletions tensorflow/core/framework/embedding/config.proto
@@ -56,3 +56,7 @@ enum ValuePosition {
IN_DRAM = 0;
NOT_IN_DRAM = 1;
}

enum IsSetInitialized {
NOT_SET_INITAILIZED = 0;
}
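
The new IsSetInitialized enum defines a single sentinel value, NOT_SET_INITAILIZED = 0, which the Python side passes through the otherwise-unused initial_num_buckets attribute to tell the kernel not to mark the variable as initialized. A minimal sketch of that mapping, with hypothetical variable names (the real logic is in InitializeKvVariableOp further down):

```cpp
#include <cstdint>
#include <iostream>

namespace embedding {
// Mirrors the enum added above: 0 is the sentinel meaning "this init op
// must not mark the variable as initialized".
enum IsSetInitialized { NOT_SET_INITAILIZED = 0 };
}  // namespace embedding

int main() {
  // The restore subgraph passes the sentinel through the unused
  // initial_num_buckets attribute (value chosen here for illustration).
  std::int64_t initial_num_buckets = embedding::NOT_SET_INITAILIZED;
  const bool is_set_initialized =
      (initial_num_buckets != embedding::NOT_SET_INITAILIZED);
  std::cout << std::boolalpha << is_set_initialized << std::endl;  // false
  return 0;
}
```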
10 changes: 6 additions & 4 deletions tensorflow/core/framework/embedding/multi_tier_storage.h
@@ -81,10 +81,12 @@ class MultiTierStorage : public Storage<K, V> {
}

void InitCache(embedding::CacheStrategy cache_strategy) override {
cache_ = CacheFactory::Create<K>(cache_strategy, name_);
eviction_manager_ = EvictionManagerCreator::Create<K, V>();
eviction_manager_->AddStorage(this);
cache_thread_pool_ = CacheThreadPoolCreator::Create();
if (cache_ == nullptr) {
cache_ = CacheFactory::Create<K>(cache_strategy, name_);
eviction_manager_ = EvictionManagerCreator::Create<K, V>();
eviction_manager_->AddStorage(this);
cache_thread_pool_ = CacheThreadPoolCreator::Create();
}
}

Status BatchCommit(const std::vector<K>& keys,
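With a separate restore initializer in the graph, InitCache can now be reached twice for the same storage object, so the change above guards it to run only once. A self-contained sketch of the same idempotent-initialization pattern, using assumed stand-in types rather than the DeepRec classes:

```cpp
#include <iostream>
#include <memory>

struct Cache {};  // stand-in for the cache created by CacheFactory

class MultiTierStorageLike {
 public:
  void InitCache() {
    // Mirrors the `if (cache_ == nullptr)` guard above: a second call,
    // e.g. from the restore initializer, becomes a no-op.
    if (cache_ == nullptr) {
      cache_ = std::make_unique<Cache>();
      std::cout << "cache and eviction machinery created" << std::endl;
    }
  }

 private:
  std::unique_ptr<Cache> cache_;
};

int main() {
  MultiTierStorageLike storage;
  storage.InitCache();  // normal initializer: creates the cache
  storage.InitCache();  // restore initializer: already set up, no-op
  return 0;
}
```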
2 changes: 2 additions & 0 deletions tensorflow/core/framework/variable.proto
@@ -74,6 +74,8 @@ message VariableDef {

// EmbeddingVariable
bool is_embedding_var = 91;

string initialize_op_for_restore = 92;
}

message SaveSliceInfoDef {
28 changes: 14 additions & 14 deletions tensorflow/core/kernels/kv_variable_ops.cc
@@ -43,11 +43,6 @@ limitations under the License.

namespace tensorflow {

namespace {
const int64 kEmbeddingVarUseDB = -214;
const int64 kInitializableEmbeddingVarUseDB = -215;
}

Status MoveMatchingFiles(
Env* env,
const tstring& pattern,
@@ -207,6 +202,15 @@ class InitializeKvVariableOp : public OpKernel {
(embedding_var_type ==
embedding::EmbeddingVariableType::IMMUTABLE);

// initial_num_buckets is otherwise unused, so it is repurposed to set is_set_initialized_.
int64 initial_num_buckets = 0;
OP_REQUIRES_OK(c, c->GetAttr("initial_num_buckets", &initial_num_buckets));
is_set_initialized_ = true;
if (initial_num_buckets ==
embedding::IsSetInitialized::NOT_SET_INITAILIZED) {
is_set_initialized_ = false;
Collaborator: Add braces even for a single-statement branch; write it as an if ... else.

}

int64 storage_type = 0;
OP_REQUIRES_OK(c, c->GetAttr("storage_type", &storage_type));
storage_type_ = static_cast<embedding::StorageType>(storage_type);
@@ -263,15 +267,10 @@
" should be DRAM when layout is 'compact'."));
}

if (steps_to_live_ == kEmbeddingVarUseDB ||
steps_to_live_ == kInitializableEmbeddingVarUseDB) {
LOG(INFO) << "hashmap use db";
//use_db_ = true;
} else {
OP_REQUIRES(c, steps_to_live_ >= 0,
errors::InvalidArgument(
OP_REQUIRES(c, steps_to_live_ >= 0,
errors::InvalidArgument(
"steps_to_live must >= 0, ", std::to_string(steps_to_live_)));
}

OP_REQUIRES_OK(c, c->GetAttr("ht_type", &ht_type_));
if (embedding::StorageType::LEVELDB == storage_type_) {
ht_type_ = "leveldb_kv";
@@ -406,7 +405,7 @@ class InitializeKvVariableOp : public OpKernel {
core::ScopedUnref unref_me(primary_variable);
}
core::ScopedUnref unref_me(ev);
if (steps_to_live_ != kEmbeddingVarUseDB) {
if (is_set_initialized_) {
Collaborator: This if replacement changes the original logic.

Collaborator (author): See line 276 of the code; this keeps the two places consistent.

ev->SetInitialized();
}
}
@@ -436,6 +435,7 @@ class InitializeKvVariableOp : public OpKernel {
bool record_freq_;
bool record_version_;
bool is_inference_;
bool is_set_initialized_;
};

#define REGISTER_KERNELS(ktype, vtype) \
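Taken together, the kernel changes above mean an InitializeKvVariableOp built with the sentinel attribute prepares the variable but leaves its initialized flag unset; only the normal initializer or the actual checkpoint restore flips it. A toy model of that behavior, using assumed names rather than the real kernel and variable classes:

```cpp
#include <iostream>

class EmbeddingVarLike {
 public:
  // Corresponds to InitializeKvVariableOp::Compute: sets up state and, only
  // when allowed, marks the variable as initialized.
  void Initialize(bool is_set_initialized) {
    if (is_set_initialized) {
      SetInitialized();
    }
  }

  // Corresponds to the restore path: after values are loaded from the
  // checkpoint, the variable is finally marked initialized.
  void RestoreFromCheckpoint() { SetInitialized(); }

  void SetInitialized() { initialized_ = true; }
  bool initialized() const { return initialized_; }

 private:
  bool initialized_ = false;
};

int main() {
  std::cout << std::boolalpha;

  EmbeddingVarLike ev;
  ev.Initialize(/*is_set_initialized=*/false);  // restore-subgraph initializer
  std::cout << ev.initialized() << std::endl;   // false: nothing restored yet

  ev.RestoreFromCheckpoint();
  std::cout << ev.initialized() << std::endl;   // true: safe to read

  return 0;
}
```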
65 changes: 65 additions & 0 deletions tensorflow/python/ops/embedding_variable_ops_test.py
@@ -2751,5 +2751,70 @@ def testCPUFbjOptWithBloomFilter(self):
self.assertNotEqual(val, 1.0)
del os.environ["TF_EMBEDDING_FBJ_OPT"]

def testSetInitializedWithoutRestore(self):
print("testSetInitializedWithoutRestore")
with ops.device("/cpu:0"):
var = variable_scope.get_embedding_variable("var_1",
embedding_dim = 3)
emb = embedding_ops.embedding_lookup(var, math_ops.cast([1], dtypes.int64))
fun = math_ops.multiply(emb, 2.0, name='multiply')
loss = math_ops.reduce_sum(fun, name='reduce_sum')
gs = training_util.get_or_create_global_step()
opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v)
init = variables.global_variables_initializer()
saver = saver_module.Saver()
with self.test_session() as sess:
result = sess.run(var._is_initialized_op)
self.assertEqual(False, result)
sess.run([init])
result = sess.run(var._is_initialized_op)
self.assertEqual(True, result)

def testSetInitializedWithRestore(self):
print("testSetInitializedWithRestore")
checkpoint_directory = self.get_temp_dir()
ckpt_path = os.path.join(checkpoint_directory, "model.ckpt")
with ops.Graph().as_default() as g, ops.device('/cpu:0'):
var = variable_scope.get_embedding_variable("var_1",
embedding_dim = 3)
emb = embedding_ops.embedding_lookup(var, math_ops.cast([1, 2, 3], dtypes.int64))
fun = math_ops.multiply(emb, 2.0, name='multiply')
loss = math_ops.reduce_sum(fun, name='reduce_sum')
gs = training_util.get_or_create_global_step()
opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v)
saver = saver_module.Saver()
init = variables.global_variables_initializer()
with self.test_session(graph=g) as sess:
sess.run([init])
sess.run(train_op)
saver.save(sess, ckpt_path)

with ops.Graph().as_default() as g, ops.device('/cpu:0'):
var = variable_scope.get_embedding_variable("var_1",
embedding_dim = 3)
emb = embedding_ops.embedding_lookup(var, math_ops.cast([1, 2, 3], dtypes.int64))
fun = math_ops.multiply(emb, 2.0, name='multiply')
loss = math_ops.reduce_sum(fun, name='reduce_sum')
gs = training_util.get_or_create_global_step()
opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v)
saver = saver_module.Saver()
init = variables.global_variables_initializer()
with self.test_session(graph=g) as sess:
result = sess.run(var._is_initialized_op)
self.assertEqual(False, result)
sess.run([var._initializer_for_restore])
result = sess.run(var._is_initialized_op)
self.assertEqual(False, result)

saver.restore(sess, ckpt_path)
result = sess.run(var._is_initialized_op)
self.assertEqual(True, result)

if __name__ == "__main__":
googletest.main()
52 changes: 52 additions & 0 deletions tensorflow/python/ops/kv_variable_ops.py
@@ -434,6 +434,8 @@ def is_multi_tier(storage_type):
with ops.control_dependencies(set_attr_ops + [self._init_op]):
self._initializer_op = control_flow_ops.no_op()

self.create_init_op_for_restore(name, initial_value, invalid_key, rank)

self._graph_element = self._handle
self._cached_value = None
if not context.executing_eagerly():
@@ -444,6 +446,49 @@ def is_multi_tier(storage_type):
def export(self):
return gen_kv_variable_ops.kv_resource_export(self._handle, Tkeys=self._invalid_key_type)


def create_init_op_for_restore(self, name, initial_value, invalid_key, rank):
with ops.control_dependencies(None if self._is_primary else [self._primary._init_op_for_restore]):
self._initializer_for_restore = gen_kv_variable_ops.initialize_kv_variable_v2_op(
self._handle,
self._primary._handle,
variables._try_guard_against_uninitialized_dependencies(name, initial_value),
ops.convert_to_tensor(invalid_key),
initial_num_buckets=config_pb2.IsSetInitialized.NOT_SET_INITAILIZED,
slot_num = self._slot_num,
shape=initial_value.get_shape()[rank:],
steps_to_live=self._steps_to_live,
emb_index=self._emb_index, block_num=self.block_num,
slot_index=self._slot_index,
ht_type=self._ht_type,
ht_partition_num=self._ht_partition_num,
filter_freq = self._filter_freq,
l2_weight_threshold = self._l2_weight_threshold,
max_element_size = self._max_element_size,
false_positive_probability = self._false_positive_probability,
counter_type = self._counter_type,
max_freq = 99999,
layout = self._layout,
storage_type = self._storage_type,
storage_path = self._storage_path,
storage_size = self._storage_size,
default_value_dim = self._default_value_dim,
default_value_no_permission = self._default_value_no_permission,
record_freq = self._record_freq,
record_version = self._record_version,
embedding_variable_type=config_pb2.EmbeddingVariableType.IMMUTABLE)
set_attr_ops = []
if self._is_primary and self._is_multi_tier:
with ops.control_dependencies([self._initializer_for_restore]):
set_cache_op = gen_kv_variable_ops.kv_resource_init_cache_strategy_op(
self._handle,
cache_strategy=self._storage_cache_strategy,
Tkeys=self._invalid_key_type,
dtype=self._dtype)
set_attr_ops.append(set_cache_op)
with ops.control_dependencies(set_attr_ops + [self._initializer_for_restore]):
self._init_op_for_restore = control_flow_ops.no_op()

def need_counts(self):
return (self._record_freq or (self._filter_freq > 0) or self._is_multi_tier)
@property
@@ -482,6 +527,11 @@ def _init_from_proto(self, variable_def, import_scope=None):
cache_op = op
elif self._initializer_op.type == "InitializeKvVariableOp":
init_op = self._initializer_op

self._init_op_for_restore = g.as_graph_element(
ops.prepend_name_scope(
variable_def.initialize_op_for_restore,
import_scope=import_scope))
self._trainable = getattr(variable_def, "trainable", True)
if variable_def.snapshot_name:
self._cached_value = g.as_graph_element(
@@ -842,6 +892,8 @@ def to_proto(self, export_scope=None):
if self._save_slice_info:
var_def.save_slice_info_def.MergeFrom(
self._save_slice_info.to_proto(export_scope=export_scope))
var_def.initialize_op_for_restore = ops.strip_name_scope(
self._init_op_for_restore.name, export_scope)
return var_def
else:
return None
3 changes: 1 addition & 2 deletions tensorflow/python/training/optimizer.py
@@ -243,8 +243,7 @@ def _get_processor(v):
if v.op.type == "KvVarHandleOp":
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework.embedding import config_pb2
v._init_op._set_attr("embedding_variable_type",
attr_value_pb2.AttrValue(i=config_pb2.EmbeddingVariableType.MUTABLE))
slot_creator._set_init_op_embedding_type_attr(v, config_pb2.EmbeddingVariableType.MUTABLE)
return _DenseResourceVariableProcessor(v)
if isinstance(v, variables.Variable):
return _RefVariableProcessor(v)
2 changes: 1 addition & 1 deletion tensorflow/python/training/saving/saveable_object_util.py
@@ -195,7 +195,7 @@ def restore(self, restored_tensors, unused_restored_shapes):
if self.var._init_data_source is not None:
return self.var.recover_from_init_data_source(self.var._init_data_source, self.partition_id, self.partition_num)
else:
with ops.control_dependencies([self.var._initializer_op]):
with ops.control_dependencies([self.var._init_op_for_restore]):
rank = self.op.initial_value.get_shape().rank - 1
restore_op = gen_kv_variable_ops.kv_resource_import_v3(
restored_tensors[0],
18 changes: 13 additions & 5 deletions tensorflow/python/training/slot_creator.py
@@ -94,8 +94,7 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype, slot_con
validate_shape=validate_shape,
steps_to_live=primary._steps_to_live,
ht_partition_num=primary._ht_partition_num)
slot._init_op._set_attr("embedding_variable_type",
attr_value_pb2.AttrValue(i=config_pb2.EmbeddingVariableType.MUTABLE))
_set_init_op_embedding_type_attr(slot, config_pb2.EmbeddingVariableType.MUTABLE)
else:
filter_strategy = None
if primary._filter_freq != 0:
@@ -107,7 +106,7 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype, slot_con
else:
filter_strategy = variables.CounterFilter(filter_freq=primary._filter_freq)
if slot_config.slot_type is config_pb2.SlotType.EMBEDDING_VARIABLE:
primary._init_op._set_attr("slot_num", attr_value_pb2.AttrValue(i=slot_config.slot_num))
_set_init_op_slot_num_attr(primary, slot_config.slot_num)
primary._slot_num = slot_config.slot_num
emb_index = primary._emb_index
if primary.block_num > 1:
@@ -132,8 +131,7 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype, slot_con
l2_weight_threshold=primary._l2_weight_threshold,
filter_strategy=filter_strategy)
)
slot._init_op._set_attr("embedding_variable_type",
attr_value_pb2.AttrValue(i=config_pb2.EmbeddingVariableType.MUTABLE))
_set_init_op_embedding_type_attr(slot, config_pb2.EmbeddingVariableType.MUTABLE)
else:
slot = variable_scope.get_variable(
scope,
@@ -300,3 +298,13 @@ def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True, slo
return create_slot(primary, val, name,
colocate_with_primary=colocate_with_primary,
slot_config=slot_config)

def _set_init_op_embedding_type_attr(var, embedding_type):
var._init_op._set_attr("embedding_variable_type",
attr_value_pb2.AttrValue(i=embedding_type))
var._initializer_for_restore._set_attr("embedding_variable_type",
attr_value_pb2.AttrValue(i=embedding_type))

def _set_init_op_slot_num_attr(var, slot_num):
var._init_op._set_attr("slot_num", attr_value_pb2.AttrValue(i=slot_num))
var._initializer_for_restore._set_attr("slot_num", attr_value_pb2.AttrValue(i=slot_num))