Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes retrieval encoders when query / item features have dense list features #1169

Merged
merged 3 commits into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion merlin/datasets/entertainment/music_streaming/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@
"annotation": {
"tag": [
"categorical",
"user_id"
"user_id",
"user"
]
}
},
Expand Down
19 changes: 15 additions & 4 deletions merlin/models/tf/core/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from merlin.models.tf.outputs.topk import TopKOutput
from merlin.models.tf.transforms.features import PrepareFeatures
from merlin.models.tf.utils import tf_utils
from merlin.models.tf.utils.batch_utils import TFModelEncode
from merlin.schema import ColumnSchema, Schema, Tags


Expand Down Expand Up @@ -171,13 +172,18 @@ def batch_predict(
if hasattr(dataset, "to_ddf"):
dataset = dataset.to_ddf()

from merlin.models.tf.utils.batch_utils import TFModelEncode

model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs)

encode_kwargs = {}
if output_schema:
encode_kwargs["filter_input_columns"] = output_schema.column_names
predictions = dataset.map_partitions(model_encode, **encode_kwargs)

# Processing a sample of the dataset with the model encoder
# to get the output dataframe dtypes
sample_output = model_encode(dataset.head(), **encode_kwargs)
output_dtypes = sample_output.dtypes.to_dict()

predictions = dataset.map_partitions(model_encode, meta=output_dtypes, **encode_kwargs)
if index:
predictions = predictions.set_index(index)

Expand Down Expand Up @@ -613,7 +619,12 @@ def batch_predict(
if output_schema:
encode_kwargs["filter_input_columns"] = output_schema.column_names

predictions = dataset.map_partitions(model_encode, **encode_kwargs)
# Processing a sample of the dataset with the model encoder
# to get the output dataframe dtypes
sample_output = model_encode(dataset.head(), **encode_kwargs)
output_dtypes = sample_output.dtypes.to_dict()

predictions = dataset.map_partitions(model_encode, meta=output_dtypes, **encode_kwargs)

return merlin.io.Dataset(predictions)

Expand Down
10 changes: 9 additions & 1 deletion merlin/models/tf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,15 @@ def get_candidates_dataset(
model_encode = TFModelEncode(model=block, output_concat_func=np.concatenate)

data = data.to_ddf()
embedding_ddf = data.map_partitions(model_encode, filter_input_columns=[id_column])

# Processing a sample of the dataset with the model encoder
# to get the output dataframe dtypes
sample_output = model_encode(data.head(), filter_input_columns=[id_column])
output_dtypes = sample_output.dtypes.to_dict()

embedding_ddf = data.map_partitions(
model_encode, meta=output_dtypes, filter_input_columns=[id_column]
)
embedding_df = embedding_ddf.compute(scheduler="synchronous")

embedding_df.set_index(id_column, inplace=True)
Expand Down
24 changes: 21 additions & 3 deletions merlin/models/tf/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1579,7 +1579,13 @@ def batch_predict(
from merlin.models.tf.utils.batch_utils import TFModelEncode

model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs)
predictions = dataset.map_partitions(model_encode)

# Processing a sample of the dataset with the model encoder
# to get the output dataframe dtypes
sample_output = model_encode(dataset.head())
output_dtypes = sample_output.dtypes.to_dict()

predictions = dataset.map_partitions(model_encode, meta=output_dtypes)

return merlin.io.Dataset(predictions)

Expand Down Expand Up @@ -2354,7 +2360,13 @@ def query_embeddings(
get_user_emb = QueryEmbeddings(self, batch_size=batch_size)

dataset = unique_rows_by_features(dataset, query_tag, query_id_tag).to_ddf()
embeddings = dataset.map_partitions(get_user_emb)

# Processing a sample of the dataset with the model encoder
# to get the output dataframe dtypes
sample_output = get_user_emb(dataset.head())
output_dtypes = sample_output.dtypes.to_dict()

embeddings = dataset.map_partitions(get_user_emb, meta=output_dtypes)

return merlin.io.Dataset(embeddings)

Expand Down Expand Up @@ -2389,7 +2401,13 @@ def item_embeddings(
get_item_emb = ItemEmbeddings(self, batch_size=batch_size)

dataset = unique_rows_by_features(dataset, item_tag, item_id_tag).to_ddf()
embeddings = dataset.map_partitions(get_item_emb)

# Processing a sample of the dataset with the model encoder
# to get the output dataframe dtypes
sample_output = get_item_emb(dataset.head())
output_dtypes = sample_output.dtypes.to_dict()

embeddings = dataset.map_partitions(get_item_emb, meta=output_dtypes)

return merlin.io.Dataset(embeddings)

Expand Down
12 changes: 2 additions & 10 deletions merlin/models/tf/utils/batch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from merlin.models.tf.core.base import Block
from merlin.models.tf.loader import Loader
from merlin.models.tf.models.base import Model, RetrievalModel, get_task_names_from_outputs
from merlin.models.utils.schema_utils import select_targets
from merlin.schema import Schema, Tags
from merlin.schema import Schema


class ModelEncode:
Expand Down Expand Up @@ -176,17 +175,10 @@ def encode_output(output: tf.Tensor):
def data_iterator_func(schema, batch_size: int = 512):
import merlin.io.dataset

cat_cols = schema.select_by_tag(Tags.CATEGORICAL).excluding_by_tag(Tags.TARGET).column_names
cont_cols = schema.select_by_tag(Tags.CONTINUOUS).excluding_by_tag(Tags.TARGET).column_names
targets = select_targets(schema).column_names

def data_iterator(dataset):
return Loader(
merlin.io.dataset.Dataset(dataset),
merlin.io.dataset.Dataset(dataset, schema=schema),
batch_size=batch_size,
cat_names=cat_cols,
cont_names=cont_cols,
label_names=targets,
shuffle=False,
)

Expand Down
37 changes: 34 additions & 3 deletions tests/unit/tf/models/test_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ def test_youtube_dnn_retrieval_v2(sequence_testing_data: Dataset, run_eagerly, t
assert losses is not None


def test_two_tower_v2_export_embeddings(
def test_two_tower_v2_export_item_tower_embeddings(
ecommerce_data: Dataset,
):
user_schema = ecommerce_data.schema.select_by_tag(Tags.USER_ID)
Expand All @@ -907,7 +907,38 @@ def test_two_tower_v2_export_embeddings(
_check_embeddings(candidates, 100, 8, "item_id")


def test_mf_v2_export_embeddings(
def test_two_tower_v2_export_item_tower_embeddings_with_seq_item_features(
music_streaming_data: Dataset,
):
# Changing the schema of the multi-hot "item_genres" feature to be
# dense (not ragged)
music_streaming_data.schema["item_genres"] = music_streaming_data.schema[
"item_genres"
].with_shape(((0, None), (4, 4)))
schema = music_streaming_data.schema
user_schema = schema.select_by_tag(Tags.USER)
candidate_schema = schema.select_by_tag(Tags.ITEM)

query = mm.Encoder(user_schema, mm.MLPBlock([8]))
candidate = mm.Encoder(candidate_schema, mm.MLPBlock([8]))
model = mm.TwoTowerModelV2(
query_tower=query, candidate_tower=candidate, negative_samplers=["in-batch"]
)

model, _ = testing_utils.model_test(model, music_streaming_data, reload_model=False)

queries = model.query_embeddings(
music_streaming_data, batch_size=16, index=Tags.USER_ID
).compute()
_check_embeddings(queries, 100, 8, "user_id")

candidates = model.candidate_embeddings(
music_streaming_data, batch_size=16, index=Tags.ITEM_ID
).compute()
_check_embeddings(candidates, 100, 8, "item_id")


def test_mf_v2_export_item_tower_embeddings(
ecommerce_data: Dataset,
):
model = mm.MatrixFactorizationModelV2(
Expand Down Expand Up @@ -939,7 +970,7 @@ def _check_embeddings(embeddings, extected_len, num_dim=8, index_name=None):
assert embeddings.index.name == index_name


def test_youtube_dnn_v2_export_embeddings(sequence_testing_data: Dataset):
def test_youtube_dnn_v2_export_item_embeddings(sequence_testing_data: Dataset):
to_remove = ["event_timestamp"] + (
sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE)
.select_by_tag(Tags.CONTINUOUS)
Expand Down