Add DLRM block #1162
Merged
Changes from all commits (13 commits)
3fd6980  First pass over DLRM-related blocks (marcromeyn)
73a2b62  add a test for dlrm block and fixes to make test pass (edknv)
c9d43da  Add unit tests (edknv)
790835b  Initialize default metrics/loss inside ModelOutput instead (marcromeyn)
233b477  Merge remote-tracking branch 'upstream/torch/default-loss-metrics' in… (edknv)
9610618  Merge branch 'main' into torch/dlrm (edknv)
6462587  Merge branch 'torch/dlrm' of github.com:NVIDIA-Merlin/models into tor… (edknv)
03a60b5  Update merlin/models/torch/blocks/dlrm.py (edknv)
3bae9b6  Merge branch 'torch/dlrm' of github.com:NVIDIA-Merlin/models into tor… (edknv)
89aeecc  Merge branch 'main' into torch/dlrm (edknv)
26ad78a  fixes and changes for failing tests (edknv)
7253078  Merge branch 'main' into torch/dlrm (edknv)
b5c0e3f  pass batch to model test (edknv)
merlin/models/torch/blocks/dlrm.py (new file)

@@ -0,0 +1,141 @@
from typing import Dict, Optional

import torch
from torch import nn

from merlin.models.torch.block import Block
from merlin.models.torch.inputs.embedding import EmbeddingTables
from merlin.models.torch.inputs.tabular import TabularInputBlock
from merlin.models.torch.link import Link
from merlin.models.torch.transforms.agg import MaybeAgg, Stack
from merlin.models.utils.doc_utils import docstring_parameter
from merlin.schema import Schema, Tags

_DLRM_REF = """
References
----------
.. [1] Naumov, Maxim, et al. "Deep learning recommendation model for
   personalization and recommendation systems." arXiv preprint arXiv:1906.00091 (2019).
"""


@docstring_parameter(dlrm_reference=_DLRM_REF)
class DLRMInputBlock(TabularInputBlock):
    """Input block for the DLRM model.

    Parameters
    ----------
    schema : Schema
        The schema to use for selection.
    dim : int
        The dimensionality of the output vectors.
    bottom_block : Block
        Block to pass the continuous features to.
        Note that the output dimensionality of this block must be equal to ``dim``.

    {dlrm_reference}

    Raises
    ------
    ValueError
        If no categorical input is provided in the schema.
    """

    def __init__(self, schema: Schema, dim: int, bottom_block: Block):
        super().__init__(schema)
        # Categorical features are embedded to ``dim``; continuous features are
        # projected to ``dim`` by the bottom block.
        self.add_route(Tags.CATEGORICAL, EmbeddingTables(dim, seq_combiner="mean"))
        self.add_route(Tags.CONTINUOUS, bottom_block)

        if "categorical" not in self:
            raise ValueError("DLRMInputBlock must have a categorical input")


@docstring_parameter(dlrm_reference=_DLRM_REF)
class DLRMInteraction(nn.Module):
    """
    This class defines the forward interaction operation as proposed
    in the DLRM `paper <https://arxiv.org/pdf/1906.00091.pdf>`_ [1]_.

    This forward operation performs elementwise multiplication
    followed by a reduction sum (equivalent to a dot product) of all embedding pairs.

    {dlrm_reference}
    """

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Lazily cache the indices of the strictly upper triangle on the first
        # call, once the number of features is known.
        if not hasattr(self, "triu_indices"):
            self.register_buffer(
                "triu_indices", torch.triu_indices(inputs.shape[1], inputs.shape[1], offset=1)
            )

        # (batch, n, d) x (batch, d, n) -> (batch, n, n) matrix of pairwise dot
        # products; keep only the upper triangle to drop duplicates and
        # self-interactions.
        interactions = torch.bmm(inputs, torch.transpose(inputs, 1, 2))
        interactions_flat = interactions[:, self.triu_indices[0], self.triu_indices[1]]

        return interactions_flat


class ShortcutConcatContinuous(Link):
    """
    A shortcut connection that concatenates
    continuous input features and intermediate outputs.

    When there is no continuous input, the intermediate output is returned.
    """

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        intermediate_output = self.output(inputs)

        if "continuous" in inputs:
            return torch.cat((inputs["continuous"], intermediate_output), dim=1)

        return intermediate_output


@docstring_parameter(dlrm_reference=_DLRM_REF)
class DLRMBlock(Block):
    """Builds the DLRM architecture, as proposed in the following
    `paper <https://arxiv.org/pdf/1906.00091.pdf>`_ [1]_.

    Parameters
    ----------
    schema : Schema
        The schema to use for selection.
    dim : int
        The dimensionality of the output vectors.
    bottom_block : Block
        Block to pass the continuous features to.
        Note that the output dimensionality of this block must be equal to ``dim``.
    top_block : Block, optional
        An optional upper-level block of the model.
    interaction : nn.Module, optional
        Interaction module for DLRM.
        If not provided, DLRMInteraction will be used by default.

    {dlrm_reference}

    Raises
    ------
    ValueError
        If no categorical input is provided in the schema.
    """

    def __init__(
        self,
        schema: Schema,
        dim: int,
        bottom_block: Block,
        top_block: Optional[Block] = None,
        interaction: Optional[nn.Module] = None,
    ):
        super().__init__(DLRMInputBlock(schema, dim, bottom_block))

        # Stack the embeddings and the bottom-block output into a single
        # (batch, num_features, dim) tensor and apply the pairwise interaction.
        # The link re-appends the continuous vector via a shortcut concat.
        self.append(
            Block(MaybeAgg(Stack(dim=1)), interaction or DLRMInteraction()),
            link=ShortcutConcatContinuous(),
        )

        if top_block:
            self.append(top_block)
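
As a quick illustration of the interaction operation above (a sketch for this writeup, not part of the PR): given a stacked tensor of n feature vectors per example, DLRMInteraction returns the n(n-1)/2 upper-triangular pairwise dot products.

import torch

from merlin.models.torch.blocks.dlrm import DLRMInteraction

# 16 examples, 5 stacked feature vectors of dimension 8 each
inputs = torch.rand(16, 5, 8)
interactions = DLRMInteraction()(inputs)

# 5 features -> 5 * 4 / 2 = 10 pairwise dot products per example
assert interactions.shape == (16, 10)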
Unit tests for the DLRM block (new file)

@@ -0,0 +1,112 @@
import math

import pytest
import torch

import merlin.models.torch as mm
from merlin.models.torch.batch import sample_batch
from merlin.models.torch.blocks.dlrm import DLRMInputBlock, DLRMInteraction
from merlin.models.torch.utils import module_utils
from merlin.schema import Tags


class TestDLRMInputBlock:
    def test_routes_and_output_shapes(self, testing_data):
        schema = testing_data.schema
        embedding_dim = 64
        block = DLRMInputBlock(schema, embedding_dim, mm.MLPBlock([embedding_dim]))

        assert isinstance(block["categorical"], mm.EmbeddingTables)
        assert len(block["categorical"]) == len(schema.select_by_tag(Tags.CATEGORICAL))

        assert isinstance(block["continuous"][0], mm.SelectKeys)
        assert isinstance(block["continuous"][1], mm.MLPBlock)

        batch_size = 16
        batch = sample_batch(testing_data, batch_size=batch_size)

        outputs = module_utils.module_test(block, batch)

        for col in schema.select_by_tag(Tags.CATEGORICAL):
            assert outputs[col.name].shape == (batch_size, embedding_dim)
        assert outputs["continuous"].shape == (batch_size, embedding_dim)


class TestDLRMInteraction:
    @pytest.mark.parametrize(
        "batch_size,num_features,dim",
        [(16, 3, 3), (32, 5, 8), (64, 5, 4)],
    )
    def test_output_shape(self, batch_size, num_features, dim):
        module = DLRMInteraction()
        inputs = torch.rand((batch_size, num_features, dim))
        outputs = module_utils.module_test(module, inputs)

        # n - 1 + C(n - 1, 2) == C(n, 2): one dot product per pair of features
        assert outputs.shape == (batch_size, num_features - 1 + math.comb(num_features - 1, 2))


class TestDLRMBlock:
    @pytest.fixture(autouse=True)
    def setup_method(self, testing_data):
        self.schema = testing_data.schema
        self.batch_size = 16
        self.batch = sample_batch(testing_data, batch_size=self.batch_size)

    def test_dlrm_output_shape(self):
        embedding_dim = 64
        block = mm.DLRMBlock(
            self.schema,
            dim=embedding_dim,
            bottom_block=mm.MLPBlock([embedding_dim]),
        )

        outputs = module_utils.module_test(block, self.batch)

        # The bottom block's output is stacked alongside the categorical
        # embeddings, so it counts as one extra feature in the interaction.
        num_features = len(self.schema.select_by_tag(Tags.CATEGORICAL)) + 1
        dot_product_dim = (num_features - 1) * num_features // 2
        # ShortcutConcatContinuous appends the embedding_dim continuous vector.
        assert list(outputs.shape) == [self.batch_size, dot_product_dim + embedding_dim]

    def test_dlrm_with_top_block(self):
        embedding_dim = 32
        top_block_dim = 8
        block = mm.DLRMBlock(
            self.schema,
            dim=embedding_dim,
            bottom_block=mm.MLPBlock([embedding_dim]),
            top_block=mm.MLPBlock([top_block_dim]),
        )

        outputs = module_utils.module_test(block, self.batch)

        assert list(outputs.shape) == [self.batch_size, top_block_dim]

    def test_dlrm_block_no_categorical_features(self):
        schema = self.schema.remove_by_tag(Tags.CATEGORICAL)
        embedding_dim = 32

        with pytest.raises(ValueError, match="must have a categorical input"):
            _ = mm.DLRMBlock(
                schema,
                dim=embedding_dim,
                bottom_block=mm.MLPBlock([embedding_dim]),
            )

    def test_dlrm_block_no_continuous_features(self, testing_data):
        schema = testing_data.schema.remove_by_tag(Tags.CONTINUOUS)
        testing_data.schema = schema

        embedding_dim = 32
        block = mm.DLRMBlock(
            schema,
            dim=embedding_dim,
            bottom_block=mm.MLPBlock([embedding_dim]),
        )

        batch_size = 16
        batch = sample_batch(testing_data, batch_size=batch_size)

        outputs = module_utils.module_test(block, batch)

        # With no continuous features there is no shortcut concat, so the
        # output is just the pairwise interactions.
        num_features = len(schema.select_by_tag(Tags.CATEGORICAL))
        dot_product_dim = (num_features - 1) * num_features // 2
        assert list(outputs.shape) == [batch_size, dot_product_dim]
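
To make the expected shapes in these tests concrete, here is a short worked example (the feature counts are hypothetical; the real ones come from the testing_data fixture):

# Hypothetical counts, for illustration only.
num_categorical = 10   # categorical features, each embedded to embedding_dim
embedding_dim = 64

n = num_categorical + 1                        # +1: the bottom block's output is stacked as one extra vector
dot_product_dim = n * (n - 1) // 2             # 55 pairwise interaction terms
output_dim = dot_product_dim + embedding_dim   # 119 after the continuous shortcut concat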
Changes to the TabularOutputBlock tests:

@@ -29,6 +29,8 @@ def test_exceptions(self):
with pytest.raises(ValueError, match="not found"): | ||
mm.TabularOutputBlock(self.schema, init="not_found") | ||
|
||
def test_no_route_for_non_existent_tag(self): | ||
outputs = mm.TabularOutputBlock(self.schema) | ||
with pytest.raises(ValueError): | ||
outputs.add_route(Tags.CATEGORICAL) | ||
outputs.add_route(Tags.CATEGORICAL) | ||
Comment on lines
-33
to
+34
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was changed to not raise an error because it is possible that there is no continuous features in the DLRM block. Alternatively, we could have a try-except in the DLRM block, but this felt more natural to me. |
||
|
||
assert not outputs |
Review comment: This test case was incorrect because feature = [[1.0, 2.0], [3.0, 4.0]] will produce a list column.