diff --git a/merlin/datasets/entertainment/music_streaming/schema.json b/merlin/datasets/entertainment/music_streaming/schema.json index cdef74879a..bfe3a14530 100644 --- a/merlin/datasets/entertainment/music_streaming/schema.json +++ b/merlin/datasets/entertainment/music_streaming/schema.json @@ -98,7 +98,8 @@ "annotation": { "tag": [ "categorical", - "user_id" + "user_id", + "user" ] } }, diff --git a/merlin/models/tf/core/encoder.py b/merlin/models/tf/core/encoder.py index 0dd47b5187..3f32a73f67 100644 --- a/merlin/models/tf/core/encoder.py +++ b/merlin/models/tf/core/encoder.py @@ -31,6 +31,7 @@ from merlin.models.tf.outputs.topk import TopKOutput from merlin.models.tf.transforms.features import PrepareFeatures from merlin.models.tf.utils import tf_utils +from merlin.models.tf.utils.batch_utils import TFModelEncode from merlin.schema import ColumnSchema, Schema, Tags @@ -171,13 +172,18 @@ def batch_predict( if hasattr(dataset, "to_ddf"): dataset = dataset.to_ddf() - from merlin.models.tf.utils.batch_utils import TFModelEncode - model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs) + encode_kwargs = {} if output_schema: encode_kwargs["filter_input_columns"] = output_schema.column_names - predictions = dataset.map_partitions(model_encode, **encode_kwargs) + + # Processing a sample of the dataset with the model encoder + # to get the output dataframe dtypes + sample_output = model_encode(dataset.head(), **encode_kwargs) + output_dtypes = sample_output.dtypes.to_dict() + + predictions = dataset.map_partitions(model_encode, meta=output_dtypes, **encode_kwargs) if index: predictions = predictions.set_index(index) @@ -613,7 +619,12 @@ def batch_predict( if output_schema: encode_kwargs["filter_input_columns"] = output_schema.column_names - predictions = dataset.map_partitions(model_encode, **encode_kwargs) + # Processing a sample of the dataset with the model encoder + # to get the output dataframe dtypes + sample_output = model_encode(dataset.head(), **encode_kwargs) + output_dtypes = sample_output.dtypes.to_dict() + + predictions = dataset.map_partitions(model_encode, meta=output_dtypes, **encode_kwargs) return merlin.io.Dataset(predictions) diff --git a/merlin/models/tf/core/index.py b/merlin/models/tf/core/index.py index fc36f1a114..ba2632f70e 100644 --- a/merlin/models/tf/core/index.py +++ b/merlin/models/tf/core/index.py @@ -113,7 +113,15 @@ def get_candidates_dataset( model_encode = TFModelEncode(model=block, output_concat_func=np.concatenate) data = data.to_ddf() - embedding_ddf = data.map_partitions(model_encode, filter_input_columns=[id_column]) + + # Processing a sample of the dataset with the model encoder + # to get the output dataframe dtypes + sample_output = model_encode(data.head(), filter_input_columns=[id_column]) + output_dtypes = sample_output.dtypes.to_dict() + + embedding_ddf = data.map_partitions( + model_encode, meta=output_dtypes, filter_input_columns=[id_column] + ) embedding_df = embedding_ddf.compute(scheduler="synchronous") embedding_df.set_index(id_column, inplace=True) diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index dd8e96a440..365c8fe9d9 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -1579,7 +1579,13 @@ def batch_predict( from merlin.models.tf.utils.batch_utils import TFModelEncode model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs) - predictions = dataset.map_partitions(model_encode) + + # Processing a sample of the dataset with the model encoder + # to get the output dataframe dtypes + sample_output = model_encode(dataset.head()) + output_dtypes = sample_output.dtypes.to_dict() + + predictions = dataset.map_partitions(model_encode, meta=output_dtypes) return merlin.io.Dataset(predictions) @@ -2354,7 +2360,13 @@ def query_embeddings( get_user_emb = QueryEmbeddings(self, batch_size=batch_size) dataset = unique_rows_by_features(dataset, query_tag, query_id_tag).to_ddf() - embeddings = dataset.map_partitions(get_user_emb) + + # Processing a sample of the dataset with the model encoder + # to get the output dataframe dtypes + sample_output = get_user_emb(dataset.head()) + output_dtypes = sample_output.dtypes.to_dict() + + embeddings = dataset.map_partitions(get_user_emb, meta=output_dtypes) return merlin.io.Dataset(embeddings) @@ -2389,7 +2401,13 @@ def item_embeddings( get_item_emb = ItemEmbeddings(self, batch_size=batch_size) dataset = unique_rows_by_features(dataset, item_tag, item_id_tag).to_ddf() - embeddings = dataset.map_partitions(get_item_emb) + + # Processing a sample of the dataset with the model encoder + # to get the output dataframe dtypes + sample_output = get_item_emb(dataset.head()) + output_dtypes = sample_output.dtypes.to_dict() + + embeddings = dataset.map_partitions(get_item_emb, meta=output_dtypes) return merlin.io.Dataset(embeddings) diff --git a/merlin/models/tf/utils/batch_utils.py b/merlin/models/tf/utils/batch_utils.py index 9ffce6b649..bc48da9ddd 100644 --- a/merlin/models/tf/utils/batch_utils.py +++ b/merlin/models/tf/utils/batch_utils.py @@ -8,8 +8,7 @@ from merlin.models.tf.core.base import Block from merlin.models.tf.loader import Loader from merlin.models.tf.models.base import Model, RetrievalModel, get_task_names_from_outputs -from merlin.models.utils.schema_utils import select_targets -from merlin.schema import Schema, Tags +from merlin.schema import Schema class ModelEncode: @@ -176,17 +175,10 @@ def encode_output(output: tf.Tensor): def data_iterator_func(schema, batch_size: int = 512): import merlin.io.dataset - cat_cols = schema.select_by_tag(Tags.CATEGORICAL).excluding_by_tag(Tags.TARGET).column_names - cont_cols = schema.select_by_tag(Tags.CONTINUOUS).excluding_by_tag(Tags.TARGET).column_names - targets = select_targets(schema).column_names - def data_iterator(dataset): return Loader( - merlin.io.dataset.Dataset(dataset), + merlin.io.dataset.Dataset(dataset, schema=schema), batch_size=batch_size, - cat_names=cat_cols, - cont_names=cont_cols, - label_names=targets, shuffle=False, ) diff --git a/tests/unit/tf/models/test_retrieval.py b/tests/unit/tf/models/test_retrieval.py index 4fd51525bd..252a666fba 100644 --- a/tests/unit/tf/models/test_retrieval.py +++ b/tests/unit/tf/models/test_retrieval.py @@ -884,7 +884,7 @@ def test_youtube_dnn_retrieval_v2(sequence_testing_data: Dataset, run_eagerly, t assert losses is not None -def test_two_tower_v2_export_embeddings( +def test_two_tower_v2_export_item_tower_embeddings( ecommerce_data: Dataset, ): user_schema = ecommerce_data.schema.select_by_tag(Tags.USER_ID) @@ -907,7 +907,38 @@ def test_two_tower_v2_export_embeddings( _check_embeddings(candidates, 100, 8, "item_id") -def test_mf_v2_export_embeddings( +def test_two_tower_v2_export_item_tower_embeddings_with_seq_item_features( + music_streaming_data: Dataset, +): + # Changing the schema of the multi-hot "item_genres" feature to be + # dense (not ragged) + music_streaming_data.schema["item_genres"] = music_streaming_data.schema[ + "item_genres" + ].with_shape(((0, None), (4, 4))) + schema = music_streaming_data.schema + user_schema = schema.select_by_tag(Tags.USER) + candidate_schema = schema.select_by_tag(Tags.ITEM) + + query = mm.Encoder(user_schema, mm.MLPBlock([8])) + candidate = mm.Encoder(candidate_schema, mm.MLPBlock([8])) + model = mm.TwoTowerModelV2( + query_tower=query, candidate_tower=candidate, negative_samplers=["in-batch"] + ) + + model, _ = testing_utils.model_test(model, music_streaming_data, reload_model=False) + + queries = model.query_embeddings( + music_streaming_data, batch_size=16, index=Tags.USER_ID + ).compute() + _check_embeddings(queries, 100, 8, "user_id") + + candidates = model.candidate_embeddings( + music_streaming_data, batch_size=16, index=Tags.ITEM_ID + ).compute() + _check_embeddings(candidates, 100, 8, "item_id") + + +def test_mf_v2_export_item_tower_embeddings( ecommerce_data: Dataset, ): model = mm.MatrixFactorizationModelV2( @@ -939,7 +970,7 @@ def _check_embeddings(embeddings, extected_len, num_dim=8, index_name=None): assert embeddings.index.name == index_name -def test_youtube_dnn_v2_export_embeddings(sequence_testing_data: Dataset): +def test_youtube_dnn_v2_export_item_embeddings(sequence_testing_data: Dataset): to_remove = ["event_timestamp"] + ( sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE) .select_by_tag(Tags.CONTINUOUS)