
Making retrieval encoders more trustworthy by setting the map_partitions() meta argument to the expected output dataframe schema. This fixes the failure that occurred when multi-hot features were used in the user / item tower encoding.
gabrielspmoreira committed Jun 30, 2023
1 parent b671c8e commit 33fb8c4
Showing 6 changed files with 80 additions and 23 deletions.
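
Before the per-file diffs, a minimal standalone sketch of the pattern this commit applies throughout (plain pandas/Dask rather than the Merlin code; all names here are illustrative): without an explicit meta, Dask infers the output schema of map_partitions by calling the function on a tiny dummy partition, which can crash or mis-infer dtypes for list-valued (multi-hot) columns. Running the function once on a real sample and passing the observed dtypes as meta sidesteps that inference.

import dask.dataframe as dd
import numpy as np
import pandas as pd

def encode_partition(df: pd.DataFrame) -> pd.DataFrame:
    # Stand-in for TFModelEncode: emits one fake embedding column per row.
    return pd.DataFrame(
        {"user_id": df["user_id"].to_numpy(), "emb_0": np.zeros(len(df), dtype="float32")}
    )

ddf = dd.from_pandas(pd.DataFrame({"user_id": [1, 2, 3, 4]}), npartitions=2)

# Run the encoder once on a real sample to learn the output dtypes ...
output_dtypes = encode_partition(ddf.head()).dtypes.to_dict()
# ... and hand them to map_partitions so Dask skips its own, fragile inference.
predictions = ddf.map_partitions(encode_partition, meta=output_dtypes)
print(predictions.compute())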
6 changes: 4 additions & 2 deletions merlin/datasets/entertainment/music_streaming/schema.json
@@ -72,7 +72,8 @@
     {
       "name": "item_genres",
       "valueCount": {
-        "min": "4"
+        "min": "4",
+        "max": "4"
       },
       "type": "INT",
       "intDomain": {
@@ -98,7 +99,8 @@
       "annotation": {
         "tag": [
           "categorical",
-          "user_id"
+          "user_id",
+          "user"
         ]
       }
     },
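
Taken together, these two schema tweaks set up the multi-hot scenario the fix targets: bounding item_genres with "max": "4" makes it a fixed-length list (multi-hot) feature, and the extra "user" tag lets user_id be selected via Tags.USER — presumably what the new test_two_tower_v2_export_item_tower_embeddings_with_seq_item_features test below relies on.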
19 changes: 15 additions & 4 deletions merlin/models/tf/core/encoder.py
@@ -31,6 +31,7 @@
 from merlin.models.tf.outputs.topk import TopKOutput
 from merlin.models.tf.transforms.features import PrepareFeatures
 from merlin.models.tf.utils import tf_utils
+from merlin.models.tf.utils.batch_utils import TFModelEncode
 from merlin.schema import ColumnSchema, Schema, Tags


Expand Down Expand Up @@ -171,13 +172,18 @@ def batch_predict(
if hasattr(dataset, "to_ddf"):
dataset = dataset.to_ddf()

from merlin.models.tf.utils.batch_utils import TFModelEncode

model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs)

encode_kwargs = {}
if output_schema:
encode_kwargs["filter_input_columns"] = output_schema.column_names
predictions = dataset.map_partitions(model_encode, **encode_kwargs)

# Processing a sample of the dataset with the model encoder
# to get the output dataframe dtypes
sample_output = model_encode(dataset.head(), **encode_kwargs)
output_dtypes = sample_output.dtypes.to_dict()

predictions = dataset.map_partitions(model_encode, meta=output_dtypes, **encode_kwargs)
if index:
predictions = predictions.set_index(index)

@@ -613,7 +619,12 @@ def batch_predict(
         if output_schema:
             encode_kwargs["filter_input_columns"] = output_schema.column_names

-        predictions = dataset.map_partitions(model_encode, **encode_kwargs)
+        # Processing a sample of the dataset with the model encoder
+        # to get the output dataframe dtypes
+        sample_output = model_encode(dataset.head(), **encode_kwargs)
+        output_dtypes = sample_output.dtypes.to_dict()
+
+        predictions = dataset.map_partitions(model_encode, meta=output_dtypes, **encode_kwargs)

         return merlin.io.Dataset(predictions)

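A hedged usage sketch of the method patched above (encoder and user_dataset are assumptions, standing in for an mm.Encoder and a merlin.io.Dataset built elsewhere): with the explicit meta, batch_predict no longer trips over multi-hot input columns when encoding a tower.

user_embeddings = encoder.batch_predict(
    user_dataset, batch_size=128, index="user_id"
).compute()  # merlin.io.Dataset -> embeddings dataframe indexed by user_id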
10 changes: 9 additions & 1 deletion merlin/models/tf/core/index.py
@@ -113,7 +113,15 @@ def get_candidates_dataset(
     model_encode = TFModelEncode(model=block, output_concat_func=np.concatenate)

     data = data.to_ddf()
-    embedding_ddf = data.map_partitions(model_encode, filter_input_columns=[id_column])
+
+    # Processing a sample of the dataset with the model encoder
+    # to get the output dataframe dtypes
+    sample_output = model_encode(data.head(), filter_input_columns=[id_column])
+    output_dtypes = sample_output.dtypes.to_dict()
+
+    embedding_ddf = data.map_partitions(
+        model_encode, meta=output_dtypes, filter_input_columns=[id_column]
+    )
     embedding_df = embedding_ddf.compute(scheduler="synchronous")

     embedding_df.set_index(id_column, inplace=True)
24 changes: 21 additions & 3 deletions merlin/models/tf/models/base.py
@@ -1579,7 +1579,13 @@ def batch_predict(
         from merlin.models.tf.utils.batch_utils import TFModelEncode

         model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs)
-        predictions = dataset.map_partitions(model_encode)
+
+        # Processing a sample of the dataset with the model encoder
+        # to get the output dataframe dtypes
+        sample_output = model_encode(dataset.head())
+        output_dtypes = sample_output.dtypes.to_dict()
+
+        predictions = dataset.map_partitions(model_encode, meta=output_dtypes)

         return merlin.io.Dataset(predictions)

@@ -2354,7 +2360,13 @@ def query_embeddings(
         get_user_emb = QueryEmbeddings(self, batch_size=batch_size)

         dataset = unique_rows_by_features(dataset, query_tag, query_id_tag).to_ddf()
-        embeddings = dataset.map_partitions(get_user_emb)
+
+        # Processing a sample of the dataset with the model encoder
+        # to get the output dataframe dtypes
+        sample_output = get_user_emb(dataset.head())
+        output_dtypes = sample_output.dtypes.to_dict()
+
+        embeddings = dataset.map_partitions(get_user_emb, meta=output_dtypes)

         return merlin.io.Dataset(embeddings)

@@ -2389,7 +2401,13 @@ def item_embeddings(
         get_item_emb = ItemEmbeddings(self, batch_size=batch_size)

         dataset = unique_rows_by_features(dataset, item_tag, item_id_tag).to_ddf()
-        embeddings = dataset.map_partitions(get_item_emb)
+
+        # Processing a sample of the dataset with the model encoder
+        # to get the output dataframe dtypes
+        sample_output = get_item_emb(dataset.head())
+        output_dtypes = sample_output.dtypes.to_dict()
+
+        embeddings = dataset.map_partitions(get_item_emb, meta=output_dtypes)

         return merlin.io.Dataset(embeddings)

12 changes: 2 additions & 10 deletions merlin/models/tf/utils/batch_utils.py
@@ -8,8 +8,7 @@
 from merlin.models.tf.core.base import Block
 from merlin.models.tf.loader import Loader
 from merlin.models.tf.models.base import Model, RetrievalModel, get_task_names_from_outputs
-from merlin.models.utils.schema_utils import select_targets
-from merlin.schema import Schema, Tags
+from merlin.schema import Schema


 class ModelEncode:
@@ -176,17 +175,10 @@ def encode_output(output: tf.Tensor):
 def data_iterator_func(schema, batch_size: int = 512):
     import merlin.io.dataset

-    cat_cols = schema.select_by_tag(Tags.CATEGORICAL).excluding_by_tag(Tags.TARGET).column_names
-    cont_cols = schema.select_by_tag(Tags.CONTINUOUS).excluding_by_tag(Tags.TARGET).column_names
-    targets = select_targets(schema).column_names
-
     def data_iterator(dataset):
         return Loader(
-            merlin.io.dataset.Dataset(dataset),
+            merlin.io.dataset.Dataset(dataset, schema=schema),
             batch_size=batch_size,
-            cat_names=cat_cols,
-            cont_names=cont_cols,
-            label_names=targets,
             shuffle=False,
         )

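A hedged sketch of how the simplified helper is driven (schema and ddf are assumptions for a Merlin Schema and a Dask dataframe): because the schema now travels with the wrapped Dataset, the Loader derives the categorical / continuous / target roles from schema tags itself, which is why the explicit cat_names / cont_names / label_names lists and the Tags / select_targets imports could be dropped.

from merlin.models.tf.utils.batch_utils import data_iterator_func

data_iterator = data_iterator_func(schema, batch_size=512)
loader = data_iterator(ddf.partitions[0].compute())  # one partition -> one schema-aware Loader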
32 changes: 29 additions & 3 deletions tests/unit/tf/models/test_retrieval.py
@@ -884,7 +884,7 @@ def test_youtube_dnn_retrieval_v2(sequence_testing_data: Dataset, run_eagerly, t
     assert losses is not None


-def test_two_tower_v2_export_embeddings(
+def test_two_tower_v2_export_item_tower_embeddings(
     ecommerce_data: Dataset,
 ):
     user_schema = ecommerce_data.schema.select_by_tag(Tags.USER_ID)
@@ -907,7 +907,33 @@ def test_two_tower_v2_export_embeddings(
     _check_embeddings(candidates, 100, 8, "item_id")


-def test_mf_v2_export_embeddings(
+def test_two_tower_v2_export_item_tower_embeddings_with_seq_item_features(
+    music_streaming_data: Dataset,
+):
+    schema = music_streaming_data.schema
+    user_schema = schema.select_by_tag(Tags.USER)
+    candidate_schema = schema.select_by_tag(Tags.ITEM)
+
+    query = mm.Encoder(user_schema, mm.MLPBlock([8]))
+    candidate = mm.Encoder(candidate_schema, mm.MLPBlock([8]))
+    model = mm.TwoTowerModelV2(
+        query_tower=query, candidate_tower=candidate, negative_samplers=["in-batch"]
+    )
+
+    model, _ = testing_utils.model_test(model, music_streaming_data, reload_model=False)
+
+    queries = model.query_embeddings(
+        music_streaming_data, batch_size=16, index=Tags.USER_ID
+    ).compute()
+    _check_embeddings(queries, 100, 8, "user_id")
+
+    candidates = model.candidate_embeddings(
+        music_streaming_data, batch_size=16, index=Tags.ITEM_ID
+    ).compute()
+    _check_embeddings(candidates, 100, 8, "item_id")
+
+
+def test_mf_v2_export_item_tower_embeddings(
     ecommerce_data: Dataset,
 ):
     model = mm.MatrixFactorizationModelV2(
Expand Down Expand Up @@ -939,7 +965,7 @@ def _check_embeddings(embeddings, extected_len, num_dim=8, index_name=None):
assert embeddings.index.name == index_name


def test_youtube_dnn_v2_export_embeddings(sequence_testing_data: Dataset):
def test_youtube_dnn_v2_export_item_embeddings(sequence_testing_data: Dataset):
to_remove = ["event_timestamp"] + (
sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE)
.select_by_tag(Tags.CONTINUOUS)
