Add DynamicEmbedding to Keras
PiperOrigin-RevId: 556908566
divyashreepathihalli authored and tensorflower-gardener committed Aug 18, 2023
1 parent 31fb21f commit 5f9359e
Showing 15 changed files with 1,748 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras/BUILD
@@ -143,6 +143,7 @@ py_library(
"//keras/protobuf:projector_config_proto_py_pb2",
"//keras/utils:engine_utils",
"//keras/utils:mode_keys",
"//keras/utils:timed_threads",
],
)

@@ -0,0 +1,99 @@
path: "tensorflow.keras.callbacks.UpdateEmbeddingCallback"
tf_class {
  is_instance: "<class \'keras.callbacks.UpdateEmbeddingCallback\'>"
  is_instance: "<class \'keras.utils.timed_threads.TimedThread\'>"
  is_instance: "<class \'keras.callbacks.Callback\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'dynamic_embedding_layer\', \'interval\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "is_alive"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "on_batch_begin"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_batch_end"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_epoch_begin"
    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_epoch_end"
    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_interval"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "on_predict_batch_begin"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_predict_batch_end"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_predict_begin"
    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_predict_end"
    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_test_batch_begin"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_test_batch_end"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_test_begin"
    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_test_end"
    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_train_batch_begin"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_train_batch_end"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_train_begin"
    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_train_end"
    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "set_model"
    argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "set_params"
    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "start"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "stop"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
4 changes: 4 additions & 0 deletions keras/api/golden/v1/tensorflow.keras.callbacks.pbtxt
@@ -60,4 +60,8 @@ tf_module {
name: "TerminateOnNaN"
mtype: "<type \'type\'>"
}
member {
name: "UpdateEmbeddingCallback"
mtype: "<type \'type\'>"
}
}
@@ -0,0 +1,99 @@
path: "tensorflow.keras.callbacks.UpdateEmbeddingCallback"
tf_class {
  is_instance: "<class \'keras.callbacks.UpdateEmbeddingCallback\'>"
  is_instance: "<class \'keras.utils.timed_threads.TimedThread\'>"
  is_instance: "<class \'keras.callbacks.Callback\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'dynamic_embedding_layer\', \'interval\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "is_alive"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "on_batch_begin"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_batch_end"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_epoch_begin"
    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_epoch_end"
    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_interval"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "on_predict_batch_begin"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_predict_batch_end"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_predict_begin"
    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_predict_end"
    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_test_batch_begin"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_test_batch_end"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_test_begin"
    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_test_end"
    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_train_batch_begin"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_train_batch_end"
    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_train_begin"
    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "on_train_end"
    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "set_model"
    argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "set_params"
    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "start"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "stop"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
4 changes: 4 additions & 0 deletions keras/api/golden/v2/tensorflow.keras.callbacks.pbtxt
@@ -64,6 +64,10 @@ tf_module {
name: "TerminateOnNaN"
mtype: "<type \'type\'>"
}
member {
name: "UpdateEmbeddingCallback"
mtype: "<type \'type\'>"
}
member {
name: "experimental"
mtype: "<type \'module\'>"
115 changes: 115 additions & 0 deletions keras/callbacks.py
@@ -40,6 +40,7 @@
from keras.utils.data_utils import Sequence
from keras.utils.generic_utils import Progbar
from keras.utils.mode_keys import ModeKeys
from keras.utils.timed_threads import TimedThread

# isort: off
from tensorflow.python.platform import tf_logging as logging
@@ -3306,3 +3307,117 @@ def __init__(
            self.on_train_begin = on_train_begin
        if on_train_end is not None:
            self.on_train_end = on_train_end


@keras_export("keras.callbacks.UpdateEmbeddingCallback")
class UpdateEmbeddingCallback(TimedThread, Callback):
"""A callback to update the DynamicEmbedding layer at specific time interval.
Updating the embedding matrix would mean that the optimizer variables will be
reset in this callback and this could have potential side effects. This means
that any existing slot variables associated with the optimizer will likely be
discarded when the optimizer is rebuilt. This affects optimizers that rely on
states of optimizer slot variables.
Example:
```
# Generate dummy data
train_data = np.array([
['a', 'j', 'c', 'd', 'e'],
['a', 'h', 'i', 'j', 'b'],
['i', 'h', 'c', 'j', 'e'],
])
train_labels = np.array([0, 1, 2])
vocab = tf.constant(['a', 'b', 'c', 'd', 'e'])
eviction_policy = 'LFU'
# Define the model
model = tf.keras.models.Sequential([
DynamicEmbedding(
input_dim=5,
output_dim=2,
input_length=5,
eviction_policy=eviction_policy,
initial_vocabulary=vocab,
),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(3, activation='softmax'),
])
# Compile the model
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'],
)
# update the vocabulary every 1 second
update_embedding_callback = UpdateEmbeddingCallback(
model.layers[0], interval=1
)
with update_embedding_callback:
result = model.fit(
train_data,
train_labels,
epochs=100,
batch_size=1,
callbacks=[update_embedding_callback],
)
```
"""

    def __init__(self, dynamic_embedding_layer, interval):
        """Initializes the timed callback.

        Args:
          dynamic_embedding_layer: The `DynamicEmbedding` layer to be updated.
          interval: The interval, in seconds, to wait between calls to the
            thread function. The thread function updates the embedding matrix
            and resets the optimizer states.
        """
        self._epoch = 0
        TimedThread.__init__(self, interval)
        Callback.__init__(self)
        self._dynamic_embedding_layer = dynamic_embedding_layer
        self.strategy = tf.distribute.get_strategy()

    def on_interval(self):
        try:
            critical_section = tf.CriticalSection()

            # Using `tf.CriticalSection` when updating embeddings from a
            # timed thread helps ensure thread safety and prevents race
            # conditions on the shared variables.
            def execute_critical_section():
                critical_section.execute(
                    lambda: self._dynamic_embedding_layer.update_embeddings(  # pylint: disable=g-long-lambda
                        self.strategy
                    )
                )

            # Update embeddings across all devices if distributed training
            # is used.
            self.strategy.run(execute_critical_section)
            # Update optimizer variables across all devices if distributed
            # training is used.
            self.strategy.run(
                lambda: self._reset_optimizer()  # pylint: disable=unnecessary-lambda
            )
        except AttributeError:
            logging.info(
                "The time interval specified for the UpdateEmbeddingCallback"
                " may be too small; try increasing the value of `interval`."
            )

    def _reset_optimizer(self):
        """Resets the optimizer variables.

        Resetting the optimizer variables is necessary after updating the
        variables in the layer. This ensures that the optimizer works with a
        consistent internal state, which prevents unexpected behavior and can
        lead to more stable and faster training.
        """
        for var in self.model.optimizer.variables():
            if "dynamic_embedding" in var.name:
                backend.set_value(var, backend.zeros_like(var))

    def on_epoch_begin(self, epoch, logs=None):
        self._epoch = epoch
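
For readers skimming the diff, the sketch below distills the concurrency pattern the new callback relies on: a background thread fires on a fixed interval and funnels its update to a shared variable through a `tf.CriticalSection`, so a concurrent training step never observes a half-written embedding matrix. This is a minimal illustration, not code from this commit: the `embedding`, `periodic_update`, and `stop` names are invented for the example, and in the real implementation `TimedThread` (added in `keras/utils/timed_threads.py`) manages the thread lifecycle.

```python
import threading
import time

import tensorflow as tf

# Stand-in for the embedding matrix owned by a DynamicEmbedding layer
# (hypothetical variable for this sketch).
embedding = tf.Variable(tf.zeros([4, 2]))
critical_section = tf.CriticalSection()


def update_embedding():
    # The CriticalSection serializes this update against any other work
    # routed through the same section, preventing races on the variable.
    return critical_section.execute(
        lambda: embedding.assign_add(tf.ones([4, 2]))
    )


stop = threading.Event()


def periodic_update(interval_sec):
    # Analogue of TimedThread.on_interval: run the update every
    # `interval_sec` seconds until asked to stop.
    while not stop.is_set():
        update_embedding()
        stop.wait(interval_sec)


worker = threading.Thread(target=periodic_update, args=(0.5,))
worker.start()
time.sleep(2)  # Stand-in for model.fit(...) running on the main thread.
stop.set()
worker.join()
print(embedding.numpy())  # Reflects several timed updates.
```

The follow-on `_reset_optimizer` step in the callback exists because rebuilding the embedding matrix invalidates the optimizer's per-variable slot state (for example, Adam's moment accumulators), so the callback zeroes the slots belonging to the dynamic embedding rather than letting stale moments be applied to fresh rows.
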
1 change: 1 addition & 0 deletions keras/layers/__init__.py
@@ -73,6 +73,7 @@
from keras.layers.core.tf_op_layer import SlicingOpLambda
from keras.layers.core.tf_op_layer import TFOpLambda


# Locally-connected layers.
from keras.layers.locally_connected.locally_connected1d import (
    LocallyConnected1D,