Fixes retrieval encoders when query / item features have dense list features #1169
Goals ⚽
This PR fixes the retrieval encoder methods (e.g. `to_top_k_model()`, `batch_predict()`), which were failing for some input features, e.g. multi-hot non-ragged item features.

Implementation Details 🚧
- Retrieval models (e.g. `RetrievalModelV2`) are composed of two towers that encode item features and query/user features separately. This allows each tower to be encoded on its own to generate the item or query embeddings.
- The encoding uses Dask `DataFrame.map_partitions()` to call the encoding function on every partition and generate the corresponding output (i.e., the output of the tower).
- When the `meta` argument is not passed to Dask `DataFrame.map_partitions()`, Dask generates fake data based on the input dataframe schema to infer the output dataframe schema. That fake data can differ from the real data; in particular, the fake data generated for dense (non-ragged) list columns, e.g. multi-hot or embedding features, causes an error when the model's encode function is called.
- This PR sets the `meta` argument of `DataFrame.map_partitions()` by manually computing the expected output dataframe schema from a sample batch of real data, which makes the encoding more robust to different types of inputs.
- It also changes the `data_iterator_func()` used by the model encoder to use the schema directly rather than the old `Loader` arguments that set `categorical`, `continuous` and `targets` separately, as the previous code did not handle list features correctly.

Testing Details 🔍
- Adds the `test_two_tower_v2_export_item_tower_embeddings_with_seq_item_features` test, which uses the `music_streaming_data` synthetic data and contains multi-hot list features (ragged and non-ragged), for which the encoding functions were failing before this fix.
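
For reviewers, the `meta` change described under Implementation Details can be illustrated with a minimal pandas-only sketch. It is not the code in this PR: `encode_partition`, `output_meta_from_sample`, the `item_genres` column and the `"foo"` placeholder values are all hypothetical names and data chosen for illustration. The idea is to run the encode function on a small batch of real data and keep only the resulting schema, instead of letting the output schema be inferred from fake data.

```python
import numpy as np
import pandas as pd


def encode_partition(df: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical stand-in for a tower's encode function.

    Stacks a dense (fixed-length) list column into a 2D array and
    reduces it to two illustrative "embedding" columns.
    """
    dense = np.stack(df["item_genres"].to_numpy()).astype("float32")
    return pd.DataFrame(
        {
            "item_id": df["item_id"].to_numpy(),
            "emb_0": dense.mean(axis=1),
            "emb_1": dense.max(axis=1),
        },
        index=df.index,
    )


def output_meta_from_sample(sample_batch: pd.DataFrame) -> pd.DataFrame:
    """Run the encoder on a small real batch and keep only the schema."""
    return encode_partition(sample_batch).iloc[:0]


# Real data: every row of the list column has the same length (dense list).
real_data = pd.DataFrame(
    {
        "item_id": [1, 2, 3],
        "item_genres": [[1, 0, 1], [0, 1, 1], [1, 1, 0]],
    }
)

# Empty dataframe carrying the expected output columns and dtypes.
meta = output_meta_from_sample(real_data.head(2))

# Made-up placeholder data (a string where a list is expected), used here
# only to mimic schema-inferred fake data that does not look like real data.
fake_data = pd.DataFrame({"item_id": [1, 1], "item_genres": ["foo", "foo"]})
try:
    encode_partition(fake_data)
    fake_data_failed = False
except (ValueError, TypeError):
    fake_data_failed = True

# With Dask, the precomputed schema would then be passed along the lines of:
#   ddf.map_partitions(encode_partition, meta=meta)
```

Computing the schema from a real sample costs one extra encoder call on a small batch, but the encode function only ever sees inputs shaped like the actual data.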