Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 663810871
  • Loading branch information
chandrasekhard2 authored and tensorflower-gardener committed Aug 16, 2024
1 parent e740e16 commit 9f5d9c0
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
2 changes: 2 additions & 0 deletions official/recommendation/ranking/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ class ModelConfig(hyperparams.Config):
max_ids_per_chip_per_sample: int | None = None
max_ids_per_table: Union[int, List[int]] | None = None
max_unique_ids_per_table: Union[int, List[int]] | None = None
allow_id_dropping: bool = False
initialize_tables_on_host: bool = False


@dataclasses.dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,13 @@ def __init__(self,
num_dense_features: int,
vocab_sizes: List[int],
multi_hot_sizes: List[int],
use_synthetic_data: bool = False,
use_cached_data: bool = False):
use_synthetic_data: bool = False):
self._file_pattern = file_pattern
self._params = params
self._num_dense_features = num_dense_features
self._vocab_sizes = vocab_sizes
self._use_synthetic_data = use_synthetic_data
self._multi_hot_sizes = multi_hot_sizes
self._use_cached_data = use_cached_data

def __call__(self, ctx: tf.distribute.InputContext) -> tf.data.Dataset:
params = self._params
Expand Down Expand Up @@ -146,7 +144,7 @@ def make_dataset(shard_index):
num_parallel_calls=tf.data.experimental.AUTOTUNE)

dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
if self._use_cached_data:
if self._params.use_cached_data:
dataset = dataset.take(1).cache().repeat()

return dataset
Expand Down
11 changes: 10 additions & 1 deletion official/recommendation/ranking/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def _get_tpu_embedding_feature_config(
max_ids_per_chip_per_sample: Optional[int] = None,
max_ids_per_table: Optional[Union[int, List[int]]] = None,
max_unique_ids_per_table: Optional[Union[int, List[int]]] = None,
allow_id_dropping: bool = False,
initialize_tables_on_host: bool = False,
) -> Tuple[
Dict[str, tf.tpu.experimental.embedding.FeatureConfig],
Optional[tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig],
Expand All @@ -57,6 +59,10 @@ def _get_tpu_embedding_feature_config(
sample.
max_ids_per_table: Maximum number of embedding ids per table.
max_unique_ids_per_table: Maximum number of unique embedding ids per table.
  allow_id_dropping: Whether embedding ids may be dropped when they exceed
    the configured per-table limits.
  initialize_tables_on_host: If True and the embedding tables are larger
    than HBM can hold, initialize the full embedding tables on the host
    and then copy shards to HBM.
Returns:
A dictionary of feature_name, FeatureConfig pairs.
Expand Down Expand Up @@ -140,7 +146,8 @@ def _get_tpu_embedding_feature_config(
max_ids_per_chip_per_sample=max_ids_per_chip_per_sample,
max_ids_per_table=max_ids_per_table_dict,
max_unique_ids_per_table=max_unique_ids_per_table_dict,
allow_id_dropping=False,
allow_id_dropping=allow_id_dropping,
initialize_tables_on_host=initialize_tables_on_host,
)

return feature_config, sparsecore_config
Expand Down Expand Up @@ -248,6 +255,8 @@ def build_model(self) -> tf_keras.Model:
max_ids_per_chip_per_sample=self.task_config.model.max_ids_per_chip_per_sample,
max_ids_per_table=self.task_config.model.max_ids_per_table,
max_unique_ids_per_table=self.task_config.model.max_unique_ids_per_table,
allow_id_dropping=self.task_config.model.allow_id_dropping,
initialize_tables_on_host=self.task_config.model.initialize_tables_on_host,
)
)

Expand Down

0 comments on commit 9f5d9c0

Please sign in to comment.