diff --git a/tensorflow/core/framework/embedding/config.proto b/tensorflow/core/framework/embedding/config.proto index 3d5fae9f6ad..a8535347020 100644 --- a/tensorflow/core/framework/embedding/config.proto +++ b/tensorflow/core/framework/embedding/config.proto @@ -56,3 +56,7 @@ enum ValuePosition { IN_DRAM = 0; NOT_IN_DRAM = 1; } + +enum IsSetInitialized { + NOT_SET_INITAILIZED = 0; +} diff --git a/tensorflow/core/framework/embedding/multi_tier_storage.h b/tensorflow/core/framework/embedding/multi_tier_storage.h index ff18425ad9a..8239d109e64 100644 --- a/tensorflow/core/framework/embedding/multi_tier_storage.h +++ b/tensorflow/core/framework/embedding/multi_tier_storage.h @@ -81,10 +81,12 @@ class MultiTierStorage : public Storage { } void InitCache(embedding::CacheStrategy cache_strategy) override { - cache_ = CacheFactory::Create(cache_strategy, name_); - eviction_manager_ = EvictionManagerCreator::Create(); - eviction_manager_->AddStorage(this); - cache_thread_pool_ = CacheThreadPoolCreator::Create(); + if (cache_ == nullptr) { + cache_ = CacheFactory::Create(cache_strategy, name_); + eviction_manager_ = EvictionManagerCreator::Create(); + eviction_manager_->AddStorage(this); + cache_thread_pool_ = CacheThreadPoolCreator::Create(); + } } Status BatchCommit(const std::vector& keys, diff --git a/tensorflow/core/framework/variable.proto b/tensorflow/core/framework/variable.proto index 79ccd107628..5f9e0f16b5d 100644 --- a/tensorflow/core/framework/variable.proto +++ b/tensorflow/core/framework/variable.proto @@ -74,6 +74,8 @@ message VariableDef { // EmebddingVariable bool is_embedding_var = 91; + + string initialize_op_for_restore = 92; } message SaveSliceInfoDef { diff --git a/tensorflow/core/kernels/kv_variable_ops.cc b/tensorflow/core/kernels/kv_variable_ops.cc index 20ea6d3cb61..8a01a7bf2cd 100644 --- a/tensorflow/core/kernels/kv_variable_ops.cc +++ b/tensorflow/core/kernels/kv_variable_ops.cc @@ -43,11 +43,6 @@ limitations under the License. namespace tensorflow { -namespace { -const int64 kEmbeddingVarUseDB = -214; -const int64 kInitializableEmbeddingVarUseDB = -215; -} - Status MoveMatchingFiles( Env* env, const tstring& pattern, @@ -207,6 +202,15 @@ class InitializeKvVariableOp : public OpKernel { (embedding_var_type == embedding::EmbeddingVariableType::IMMUTABLE); + //initial_num_buckets is useless, so is used to set is_set_initialized_. + int64 initial_num_buckets = 0; + OP_REQUIRES_OK(c, c->GetAttr("initial_num_buckets", &initial_num_buckets)); + is_set_initialized_ = true; + if (initial_num_buckets == + embedding::IsSetInitialized::NOT_SET_INITAILIZED) { + is_set_initialized_ = false; + } + int64 storage_type = 0; OP_REQUIRES_OK(c, c->GetAttr("storage_type", &storage_type)); storage_type_ = static_cast(storage_type); @@ -263,15 +267,10 @@ class InitializeKvVariableOp : public OpKernel { " should be DRAM when layout is 'compact'.")); } - if (steps_to_live_ == kEmbeddingVarUseDB || - steps_to_live_ == kInitializableEmbeddingVarUseDB) { - LOG(INFO) << "hashmap use db"; - //use_db_ = true; - } else { - OP_REQUIRES(c, steps_to_live_ >= 0, - errors::InvalidArgument( + OP_REQUIRES(c, steps_to_live_ >= 0, + errors::InvalidArgument( "steps_to_live must >= 0, ", std::to_string(steps_to_live_))); - } + OP_REQUIRES_OK(c, c->GetAttr("ht_type", &ht_type_)); if (embedding::StorageType::LEVELDB == storage_type_) { ht_type_ = "leveldb_kv"; @@ -406,7 +405,7 @@ class InitializeKvVariableOp : public OpKernel { core::ScopedUnref unref_me(primary_variable); } core::ScopedUnref unref_me(ev); - if (steps_to_live_ != kEmbeddingVarUseDB) { + if (is_set_initialized_) { ev->SetInitialized(); } } @@ -436,6 +435,7 @@ class InitializeKvVariableOp : public OpKernel { bool record_freq_; bool record_version_; bool is_inference_; + bool is_set_initialized_; }; #define REGISTER_KERNELS(ktype, vtype) \ diff --git a/tensorflow/python/ops/embedding_variable_ops_test.py b/tensorflow/python/ops/embedding_variable_ops_test.py index d3e453df9d1..25a0cb6ff11 100644 --- a/tensorflow/python/ops/embedding_variable_ops_test.py +++ b/tensorflow/python/ops/embedding_variable_ops_test.py @@ -2751,5 +2751,70 @@ def testCPUFbjOptWithBloomFilter(self): self.assertNotEqual(val, 1.0) del os.environ["TF_EMBEDDING_FBJ_OPT"] + def testSetInitializedWithoutRestore(self): + print("testSetInitializedWithoutRestore") + with ops.device("/cpu:0"): + var = variable_scope.get_embedding_variable("var_1", + embedding_dim = 3) + emb = embedding_ops.embedding_lookup(var, math_ops.cast([1], dtypes.int64)) + fun = math_ops.multiply(emb, 2.0, name='multiply') + loss = math_ops.reduce_sum(fun, name='reduce_sum') + gs = training_util.get_or_create_global_step() + opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs) + g_v = opt.compute_gradients(loss) + train_op = opt.apply_gradients(g_v) + init = variables.global_variables_initializer() + saver = saver_module.Saver() + with self.test_session() as sess: + result = sess.run(var._is_initialized_op) + self.assertEqual(False, result) + sess.run([init]) + result = sess.run(var._is_initialized_op) + self.assertEqual(True, result) + + def testSetInitializedWithRestore(self): + print("testSetInitializedWitRestore") + checkpoint_directory = self.get_temp_dir() + ckpt_path = os.path.join(checkpoint_directory, "model.ckpt") + with ops.Graph().as_default() as g, ops.device('/cpu:0'): + var = variable_scope.get_embedding_variable("var_1", + embedding_dim = 3) + emb = embedding_ops.embedding_lookup(var, math_ops.cast([1,2 ,3], dtypes.int64)) + fun = math_ops.multiply(emb, 2.0, name='multiply') + loss = math_ops.reduce_sum(fun, name='reduce_sum') + gs = training_util.get_or_create_global_step() + opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs) + g_v = opt.compute_gradients(loss) + train_op = opt.apply_gradients(g_v) + saver = saver_module.Saver() + init = variables.global_variables_initializer() + with self.test_session(graph=g) as sess: + sess.run([init]) + sess.run(train_op) + saver.save(sess, ckpt_path) + + with ops.Graph().as_default() as g, ops.device('/cpu:0'): + var = variable_scope.get_embedding_variable("var_1", + embedding_dim = 3) + emb = embedding_ops.embedding_lookup(var, math_ops.cast([1, 2, 3], dtypes.int64)) + fun = math_ops.multiply(emb, 2.0, name='multiply') + loss = math_ops.reduce_sum(fun, name='reduce_sum') + gs = training_util.get_or_create_global_step() + opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs) + g_v = opt.compute_gradients(loss) + train_op = opt.apply_gradients(g_v) + saver = saver_module.Saver() + init = variables.global_variables_initializer() + with self.test_session(graph=g) as sess: + result = sess.run(var._is_initialized_op) + self.assertEqual(False, result) + sess.run([var._initializer_for_restore]) + result = sess.run(var._is_initialized_op) + self.assertEqual(False, result) + + saver.restore(sess, ckpt_path) + result = sess.run(var._is_initialized_op) + self.assertEqual(True, result) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/kv_variable_ops.py b/tensorflow/python/ops/kv_variable_ops.py index e6140c9c149..701c03f6975 100644 --- a/tensorflow/python/ops/kv_variable_ops.py +++ b/tensorflow/python/ops/kv_variable_ops.py @@ -434,6 +434,8 @@ def is_multi_tier(storage_type): with ops.control_dependencies(set_attr_ops + [self._init_op]): self._initializer_op = control_flow_ops.no_op() + self.create_init_op_for_restore(name, initial_value, invalid_key, rank) + self._graph_element = self._handle self._cached_value = None if not context.executing_eagerly(): @@ -444,6 +446,49 @@ def is_multi_tier(storage_type): def export(self): return gen_kv_variable_ops.kv_resource_export(self._handle, Tkeys=self._invalid_key_type) + + def create_init_op_for_restore(self, name, initial_value, invalid_key, rank): + with ops.control_dependencies(None if self._is_primary else [self._primary._init_op_for_restore]): + self._initializer_for_restore = gen_kv_variable_ops.initialize_kv_variable_v2_op( + self._handle, + self._primary._handle, + variables._try_guard_against_uninitialized_dependencies(name, initial_value), + ops.convert_to_tensor(invalid_key), + initial_num_buckets=config_pb2.IsSetInitialized.NOT_SET_INITAILIZED, + slot_num = self._slot_num, + shape=initial_value.get_shape()[rank:], + steps_to_live=self._steps_to_live, + emb_index=self._emb_index, block_num=self.block_num, + slot_index=self._slot_index, + ht_type=self._ht_type, + ht_partition_num=self._ht_partition_num, + filter_freq = self._filter_freq, + l2_weight_threshold = self._l2_weight_threshold, + max_element_size = self._max_element_size, + false_positive_probability = self._false_positive_probability, + counter_type = self._counter_type, + max_freq = 99999, + layout = self._layout, + storage_type = self._storage_type, + storage_path = self._storage_path, + storage_size = self._storage_size, + default_value_dim = self._default_value_dim, + default_value_no_permission = self._default_value_no_permission, + record_freq = self._record_freq, + record_version = self._record_version, + embedding_variable_type=config_pb2.EmbeddingVariableType.IMMUTABLE) + set_attr_ops = [] + if self._is_primary and self._is_multi_tier: + with ops.control_dependencies([self._initializer_for_restore]): + set_cache_op = gen_kv_variable_ops.kv_resource_init_cache_strategy_op( + self._handle, + cache_strategy=self._storage_cache_strategy, + Tkeys=self._invalid_key_type, + dtype=self._dtype) + set_attr_ops.append(set_cache_op) + with ops.control_dependencies(set_attr_ops + [self._initializer_for_restore]): + self._init_op_for_restore = control_flow_ops.no_op() + def need_counts(self): return (self._record_freq or (self._filter_freq > 0) or self._is_multi_tier) @property @@ -482,6 +527,11 @@ def _init_from_proto(self, variable_def, import_scope=None): cache_op = op elif self._initializer_op.type == "InitializeKvVariableOp": init_op = self._initializer_op + + self._init_op_for_restore = g.as_graph_element( + ops.prepend_name_scope( + variable_def.initialize_op_for_restore, + import_scope=import_scope)) self._trainable = getattr(variable_def, "trainable", True) if variable_def.snapshot_name: self._cached_value = g.as_graph_element( @@ -842,6 +892,8 @@ def to_proto(self, export_scope=None): if self._save_slice_info: var_def.save_slice_info_def.MergeFrom( self._save_slice_info.to_proto(export_scope=export_scope)) + var_def.initialize_op_for_restore = ops.strip_name_scope( + self._init_op_for_restore.name, export_scope) return var_def else: return None diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 2b765814c0d..578d682cc11 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -243,8 +243,7 @@ def _get_processor(v): if v.op.type == "KvVarHandleOp": from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework.embedding import config_pb2 - v._init_op._set_attr("embedding_variable_type", - attr_value_pb2.AttrValue(i=config_pb2.EmbeddingVariableType.MUTABLE)) + slot_creator._set_init_op_embedding_type_attr(v, config_pb2.EmbeddingVariableType.MUTABLE) return _DenseResourceVariableProcessor(v) if isinstance(v, variables.Variable): return _RefVariableProcessor(v) diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py index cd3cba52676..0d8bfe87022 100644 --- a/tensorflow/python/training/saving/saveable_object_util.py +++ b/tensorflow/python/training/saving/saveable_object_util.py @@ -195,7 +195,7 @@ def restore(self, restored_tensors, unused_restored_shapes): if self.var._init_data_source is not None: return self.var.recover_from_init_data_source(self.var._init_data_source, self.partition_id, self.partition_num) else: - with ops.control_dependencies([self.var._initializer_op]): + with ops.control_dependencies([self.var._init_op_for_restore]): rank = self.op.initial_value.get_shape().rank - 1 restore_op = gen_kv_variable_ops.kv_resource_import_v3( restored_tensors[0], diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py index 90a820d82f6..6a359321c20 100644 --- a/tensorflow/python/training/slot_creator.py +++ b/tensorflow/python/training/slot_creator.py @@ -94,8 +94,7 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype, slot_con validate_shape=validate_shape, steps_to_live=primary._steps_to_live, ht_partition_num=primary._ht_partition_num) - slot._init_op._set_attr("embedding_variable_type", - attr_value_pb2.AttrValue(i=config_pb2.EmbeddingVariableType.MUTABLE)) + _set_init_op_embedding_type_attr(slot, config_pb2.EmbeddingVariableType.MUTABLE) else: filter_strategy = None if primary._filter_freq != 0: @@ -107,7 +106,7 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype, slot_con else: filter_strategy = variables.CounterFilter(filter_freq=primary._filter_freq) if slot_config.slot_type is config_pb2.SlotType.EMBEDDING_VARIABLE: - primary._init_op._set_attr("slot_num", attr_value_pb2.AttrValue(i=slot_config.slot_num)) + _set_init_op_slot_num_attr(primary, slot_config.slot_num) primary._slot_num = slot_config.slot_num emb_index = primary._emb_index if primary.block_num > 1: @@ -132,8 +131,7 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype, slot_con l2_weight_threshold=primary._l2_weight_threshold, filter_strategy=filter_strategy) ) - slot._init_op._set_attr("embedding_variable_type", - attr_value_pb2.AttrValue(i=config_pb2.EmbeddingVariableType.MUTABLE)) + _set_init_op_embedding_type_attr(slot, config_pb2.EmbeddingVariableType.MUTABLE) else: slot = variable_scope.get_variable( scope, @@ -300,3 +298,13 @@ def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True, slo return create_slot(primary, val, name, colocate_with_primary=colocate_with_primary, slot_config=slot_config) + +def _set_init_op_embedding_type_attr(var, embedding_type): + var._init_op._set_attr("embedding_variable_type", + attr_value_pb2.AttrValue(i=embedding_type)) + var._initializer_for_restore._set_attr("embedding_variable_type", + attr_value_pb2.AttrValue(i=embedding_type)) + +def _set_init_op_slot_num_attr(var, slot_num): + var._init_op._set_attr("slot_num", attr_value_pb2.AttrValue(i=slot_num)) + var._initializer_for_restore._set_attr("slot_num", attr_value_pb2.AttrValue(i=slot_num))