add parameters for SparseCore Embedding Config on v5p to run DLRM models on cloud v5p for tutorial

PiperOrigin-RevId: 650336609
ZhaoyueCheng authored and tensorflower-gardener committed Jul 8, 2024
1 parent 8a7364c commit 7e541dd
Showing 3 changed files with 177 additions and 48 deletions.
10 changes: 10 additions & 0 deletions official/recommendation/ranking/configs/config.py
@@ -112,6 +112,11 @@ class ModelConfig(hyperparams.Config):
module
dcn_use_bias: Flag to determine whether to use bias for the dcn interaction
module
use_partial_tpu_embedding: Flag to determine whether to use partial tpu
embedding layer or not.
max_ids_per_chip_per_sample: Maximum number of ids per chip per sample.
max_ids_per_table: Maximum number of ids per table.
max_unique_ids_per_table: Maximum number of unique ids per table.
"""
num_dense_features: int = 13
vocab_sizes: List[int] = dataclasses.field(default_factory=list)
@@ -128,6 +133,10 @@ class ModelConfig(hyperparams.Config):
dcn_kernel_initializer: str = 'truncated_normal'
dcn_bias_initializer: str = 'zeros'
dcn_use_bias: bool = True
use_partial_tpu_embedding: bool = True
max_ids_per_chip_per_sample: int | None = None
max_ids_per_table: Union[int, List[int]] | None = None
max_unique_ids_per_table: Union[int, List[int]] | None = None


@dataclasses.dataclass
@@ -424,6 +433,7 @@ def dlrm_dcn_v2_criteo_tb_config() -> Config:
dcn_use_bias=True,
concat_dense=False,
use_multi_hot=True,
use_partial_tpu_embedding=False,
multi_hot_sizes=multi_hot_sizes,
),
loss=Loss(label_smoothing=0.0),
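For readers following the v5p tutorial, the snippet below is a minimal sketch of how the new ModelConfig fields might be set in code; the import path matches this repository, but the vocabulary sizes, embedding dimension, and per-table limits are illustrative placeholders, not tuned values, and all other fields keep their defaults.

# Minimal sketch (illustrative values): overriding the SparseCore-related
# fields added to ModelConfig above. Assumes the model garden package
# (official/) is importable.
from official.recommendation.ranking.configs import config

model_config = config.ModelConfig(
    num_dense_features=13,
    vocab_sizes=[1000, 2000, 5000],
    embedding_dim=8,
    # On v5p, skip the PartialTPUEmbedding path so the task can attach a
    # SparseCoreEmbeddingConfig to the plain TPUEmbedding layer.
    use_partial_tpu_embedding=False,
    # Limits consumed by SparseCoreEmbeddingConfig in task.py; either a single
    # int (broadcast to every table) or one entry per vocab size.
    max_ids_per_chip_per_sample=64,
    max_ids_per_table=[256, 256, 256],
    max_unique_ids_per_table=[128, 128, 128],
)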
110 changes: 92 additions & 18 deletions official/recommendation/ranking/task.py
@@ -15,7 +15,7 @@
"""Task for the Ranking model."""

import math
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Tuple

import tensorflow as tf, tf_keras
import tensorflow_recommenders as tfrs
@@ -35,8 +35,14 @@ def _get_tpu_embedding_feature_config(
vocab_sizes: List[int],
embedding_dim: Union[int, List[int]],
table_name_prefix: str = 'embedding_table',
batch_size: Optional[int] = None
) -> Dict[str, tf.tpu.experimental.embedding.FeatureConfig]:
batch_size: Optional[int] = None,
max_ids_per_chip_per_sample: Optional[int] = None,
max_ids_per_table: Optional[Union[int, List[int]]] = None,
max_unique_ids_per_table: Optional[Union[int, List[int]]] = None,
) -> Tuple[
Dict[str, tf.tpu.experimental.embedding.FeatureConfig],
Optional[tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig],
]:
"""Returns TPU embedding feature config.
i'th table config will have vocab size of vocab_sizes[i] and embedding
@@ -47,37 +53,97 @@
embedding_dim: An integer or a list of embedding table dimensions.
table_name_prefix: a prefix for embedding tables.
batch_size: Per-replica batch size.
max_ids_per_chip_per_sample: Maximum number of embedding ids per chip per
sample.
max_ids_per_table: Maximum number of embedding ids per table.
max_unique_ids_per_table: Maximum number of unique embedding ids per table.
Returns:
A tuple of a dictionary of feature_name, FeatureConfig pairs and an
optional SparseCoreEmbeddingConfig (None unless all SparseCore limits are
provided).
"""
if isinstance(embedding_dim, List):
if len(vocab_sizes) != len(embedding_dim):
raise ValueError(
f'length of vocab_sizes: {len(vocab_sizes)} is not equal to the '
f'length of embedding_dim: {len(embedding_dim)}')
f'length of embedding_dim: {len(embedding_dim)}'
)
elif isinstance(embedding_dim, int):
embedding_dim = [embedding_dim] * len(vocab_sizes)
else:
raise ValueError('embedding_dim is not either a list or an int, got '
f'{type(embedding_dim)}')
raise ValueError(
'embedding_dim is not either a list or an int, got '
f'{type(embedding_dim)}'
)

if isinstance(max_ids_per_table, List):
if len(vocab_sizes) != len(max_ids_per_table):
raise ValueError(
f'length of vocab_sizes: {len(vocab_sizes)} is not equal to the '
f'length of max_ids_per_table: {len(max_ids_per_table)}'
)
elif isinstance(max_ids_per_table, int):
max_ids_per_table = [max_ids_per_table] * len(vocab_sizes)
elif max_ids_per_table is not None:
raise ValueError(
'max_ids_per_table is not either a list or an int or None, got '
f'{type(max_ids_per_table)}'
)

if isinstance(max_unique_ids_per_table, List):
if len(vocab_sizes) != len(max_unique_ids_per_table):
raise ValueError(
f'length of vocab_sizes: {len(vocab_sizes)} is not equal to the '
'length of max_unique_ids_per_table: '
f'{len(max_unique_ids_per_table)}'
)
elif isinstance(max_unique_ids_per_table, int):
max_unique_ids_per_table = [max_unique_ids_per_table] * len(vocab_sizes)
elif max_unique_ids_per_table is not None:
raise ValueError(
'max_unique_ids_per_table is not either a list or an int or None, '
f'got {type(max_unique_ids_per_table)}'
)

feature_config = {}
sparsecore_config = None
max_ids_per_table_dict = {}
max_unique_ids_per_table_dict = {}

for i, vocab_size in enumerate(vocab_sizes):
table_config = tf.tpu.experimental.embedding.TableConfig(
vocabulary_size=vocab_size,
dim=embedding_dim[i],
combiner='mean',
initializer=tf.initializers.TruncatedNormal(
mean=0.0, stddev=1 / math.sqrt(embedding_dim[i])),
name=table_name_prefix + '_%02d' % i)
mean=0.0, stddev=1 / math.sqrt(embedding_dim[i])
),
name=table_name_prefix + '_%02d' % i,
)
feature_config[str(i)] = tf.tpu.experimental.embedding.FeatureConfig(
name=str(i),
table=table_config,
output_shape=[batch_size] if batch_size else None,
)
if max_ids_per_table:
max_ids_per_table_dict[str(table_name_prefix + '_%02d' % i)] = (
max_ids_per_table[i]
)
if max_unique_ids_per_table:
max_unique_ids_per_table_dict[str(table_name_prefix + '_%02d' % i)] = (
max_unique_ids_per_table[i]
)

return feature_config
if all((max_ids_per_chip_per_sample, max_ids_per_table,
max_unique_ids_per_table)):
sparsecore_config = tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig(
disable_table_stacking=False,
max_ids_per_chip_per_sample=max_ids_per_chip_per_sample,
max_ids_per_table=max_ids_per_table_dict,
max_unique_ids_per_table=max_unique_ids_per_table_dict,
allow_id_dropping=False,
)

return feature_config, sparsecore_config


class RankingTask(base_task.Task):
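To illustrate the new return contract, here is a hedged sketch of calling the helper directly; the argument values are made up, and importing the module-private function is done only for demonstration.

# Sketch only: _get_tpu_embedding_feature_config is module-private; the
# direct import and the values below are purely illustrative.
from official.recommendation.ranking.task import _get_tpu_embedding_feature_config

feature_config, sparsecore_config = _get_tpu_embedding_feature_config(
    vocab_sizes=[1000, 2000],
    embedding_dim=16,
    batch_size=128,
    max_ids_per_chip_per_sample=64,
    max_ids_per_table=[256, 512],
    max_unique_ids_per_table=[128, 256],
)
# sparsecore_config is None unless all three SparseCore limits are provided.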
@@ -173,25 +239,33 @@ def build_model(self) -> tf_keras.Model:
decay_start_steps=dense_lr_config.decay_start_steps)
dense_optimizer.learning_rate = dense_lr_callable

feature_config = _get_tpu_embedding_feature_config(
embedding_dim=self.task_config.model.embedding_dim,
vocab_sizes=self.task_config.model.vocab_sizes,
batch_size=self.task_config.train_data.global_batch_size
// tf.distribute.get_strategy().num_replicas_in_sync,
feature_config, sparse_core_embedding_config = (
_get_tpu_embedding_feature_config(
embedding_dim=self.task_config.model.embedding_dim,
vocab_sizes=self.task_config.model.vocab_sizes,
batch_size=self.task_config.train_data.global_batch_size
// tf.distribute.get_strategy().num_replicas_in_sync,
max_ids_per_chip_per_sample=self.task_config.model.max_ids_per_chip_per_sample,
max_ids_per_table=self.task_config.model.max_ids_per_table,
max_unique_ids_per_table=self.task_config.model.max_unique_ids_per_table,
)
)

if self.task_config.model.use_multi_hot:
embedding_layer = tfrs.layers.embedding.tpu_embedding_layer.TPUEmbedding(
# to work around PartialTPUEmbedding issue in v5p and to enable multi hot
# features
if self.task_config.model.use_partial_tpu_embedding:
embedding_layer = tfrs.experimental.layers.embedding.PartialTPUEmbedding(
feature_config=feature_config,
optimizer=embedding_optimizer,
pipeline_execution_with_tensor_core=self.trainer_config.pipeline_sparse_and_dense_execution,
size_threshold=self.task_config.model.size_threshold,
)
else:
embedding_layer = tfrs.experimental.layers.embedding.PartialTPUEmbedding(
embedding_layer = tfrs.layers.embedding.tpu_embedding_layer.TPUEmbedding(
feature_config=feature_config,
optimizer=embedding_optimizer,
pipeline_execution_with_tensor_core=self.trainer_config.pipeline_sparse_and_dense_execution,
size_threshold=self.task_config.model.size_threshold,
sparse_core_embedding_config=sparse_core_embedding_config,
)

if self.task_config.model.interaction == 'dot':
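For reference, this is a sketch of the SparseCoreEmbeddingConfig the task ends up building when all three limits are set. The keyword arguments mirror the diff above, the table-name keys follow the 'embedding_table_%02d' naming used by the helper, and the numeric limits are illustrative rather than recommended values.

import tensorflow as tf

# Sketch of the config constructed in _get_tpu_embedding_feature_config when
# max_ids_per_chip_per_sample, max_ids_per_table and max_unique_ids_per_table
# are all provided (values illustrative).
sparse_core_embedding_config = tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig(
    disable_table_stacking=False,
    max_ids_per_chip_per_sample=64,
    max_ids_per_table={'embedding_table_00': 256, 'embedding_table_01': 512},
    max_unique_ids_per_table={'embedding_table_00': 128, 'embedding_table_01': 256},
    allow_id_dropping=False,
)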
