Add DLRM block #1162

Merged
merged 13 commits on Jul 1, 2023
2 changes: 2 additions & 0 deletions merlin/models/torch/__init__.py
@@ -17,6 +17,7 @@
from merlin.models.torch import schema
from merlin.models.torch.batch import Batch, Sequence
from merlin.models.torch.block import Block, ParallelBlock
from merlin.models.torch.blocks.dlrm import DLRMBlock
from merlin.models.torch.blocks.mlp import MLPBlock
from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables
from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys
@@ -51,4 +52,5 @@
"Concat",
"Stack",
"schema",
"DLRMBlock",
]
141 changes: 141 additions & 0 deletions merlin/models/torch/blocks/dlrm.py
@@ -0,0 +1,141 @@
from typing import Dict, Optional

import torch
from torch import nn

from merlin.models.torch.block import Block
from merlin.models.torch.inputs.embedding import EmbeddingTables
from merlin.models.torch.inputs.tabular import TabularInputBlock
from merlin.models.torch.link import Link
from merlin.models.torch.transforms.agg import MaybeAgg, Stack
from merlin.models.utils.doc_utils import docstring_parameter
from merlin.schema import Schema, Tags

_DLRM_REF = """
References
----------
.. [1] Naumov, Maxim, et al. "Deep learning recommendation model for
personalization and recommendation systems." arXiv preprint arXiv:1906.00091 (2019).
"""


@docstring_parameter(dlrm_reference=_DLRM_REF)
class DLRMInputBlock(TabularInputBlock):
"""Input block for DLRM model.

Parameters
----------
schema : Schema
The schema to use for selection.
dim : int
The dimensionality of the output vectors.
bottom_block : Block
Block to pass the continuous features to.
Note that the output dimensionality of this block must be equal to ``dim``.

{dlrm_reference}

Raises
------
ValueError
If no categorical input is provided in the schema.

"""

def __init__(self, schema: Schema, dim: int, bottom_block: Block):
super().__init__(schema)
self.add_route(Tags.CATEGORICAL, EmbeddingTables(dim, seq_combiner="mean"))
self.add_route(Tags.CONTINUOUS, bottom_block)

if "categorical" not in self:
raise ValueError("DLRMInputBlock must have a categorical input")


@docstring_parameter(dlrm_reference=_DLRM_REF)
class DLRMInteraction(nn.Module):
"""
This class defines the forward interaction operation as proposed
in the DLRM
`paper <https://arxiv.org/pdf/1906.00091.pdf>`_ [1]_.

This forward operation performs elementwise multiplication
followed by a reduction sum (equivalent to a dot product) of all embedding pairs.

{dlrm_reference}

"""

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
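# Lazily build the upper-triangular pair indices on the first forward pass,
# once the number of stacked features is known.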
if not hasattr(self, "triu_indices"):
self.register_buffer(
"triu_indices", torch.triu_indices(inputs.shape[1], inputs.shape[1], offset=1)
)

interactions = torch.bmm(inputs, torch.transpose(inputs, 1, 2))
interactions_flat = interactions[:, self.triu_indices[0], self.triu_indices[1]]

return interactions_flat


class ShortcutConcatContinuous(Link):
"""
A shortcut connection that concatenates
continuous input features and intermediate outputs.

When there's no continuous input, the intermediate output is returned.
"""

def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
intermediate_output = self.output(inputs)

if "continuous" in inputs:
return torch.cat((inputs["continuous"], intermediate_output), dim=1)

return intermediate_output


@docstring_parameter(dlrm_reference=_DLRM_REF)
class DLRMBlock(Block):
"""Builds the DLRM architecture, as proposed in the following
`paper <https://arxiv.org/pdf/1906.00091.pdf>`_ [1]_.

Parameters
----------
schema : Schema
The schema to use for selection.
dim : int
The dimensionality of the output vectors.
bottom_block : Block
Block to pass the continuous features to.
Note that the output dimensionality of this block must be equal to ``dim``.
top_block : Block, optional
An optional upper-level block of the model.
interaction : nn.Module, optional
Interaction module for DLRM.
If not provided, DLRMInteraction will be used by default.

{dlrm_reference}

Raises
------
ValueError
If no categorical input is provided in the schema.
"""

def __init__(
self,
schema: Schema,
dim: int,
bottom_block: Block,
top_block: Optional[Block] = None,
interaction: Optional[nn.Module] = None,
):
super().__init__(DLRMInputBlock(schema, dim, bottom_block))

self.append(
Block(MaybeAgg(Stack(dim=1)), interaction or DLRMInteraction()),
link=ShortcutConcatContinuous(),
)

if top_block:
self.append(top_block)
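
For reference, a minimal usage sketch (not part of the diff), adapted from the unit tests further down; `schema` and `dataset` stand in for any Merlin dataset whose schema tags categorical and, optionally, continuous columns:

import torch
import merlin.models.torch as mm
from merlin.models.torch.batch import sample_batch
from merlin.models.torch.blocks.dlrm import DLRMInteraction
from merlin.models.torch.utils import module_utils

# The interaction module on its own: 5 stacked vectors of size 8 yield
# 5-choose-2 = 10 pairwise dot products per example.
pairs = DLRMInteraction()(torch.rand(16, 5, 8))  # shape: (16, 10)

# The full block: embeddings for the categorical columns, a bottom MLP for the
# continuous ones, pairwise interactions, and an optional top MLP.
embedding_dim = 64
block = mm.DLRMBlock(
    schema,
    dim=embedding_dim,
    bottom_block=mm.MLPBlock([embedding_dim]),  # must output `dim`-sized vectors
    top_block=mm.MLPBlock([8]),                 # optional
)
batch = sample_batch(dataset, batch_size=16)
outputs = module_utils.module_test(block, batch)  # runs a forward pass, as in the tests below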
3 changes: 3 additions & 0 deletions merlin/models/torch/router.py
@@ -88,6 +88,9 @@ def add_route(
"""

routing_module = schema.select(self.selectable, selection)
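# If the selection matches nothing (e.g. the schema has no continuous columns),
# skip adding a route instead of raising.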
if not routing_module:
return self

if module is not None:
schema.setup_schema(module, routing_module.schema)

3 changes: 3 additions & 0 deletions merlin/models/torch/transforms/agg.py
@@ -104,6 +104,9 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:

if self.align_dims:
max_dims = max(tensor.dim() for tensor in sorted_tensors)
max_dims = max(
max_dims, 2
)  # assume a batch dimension plus at least one feature dimension
_sorted_tensors = []
for tensor in sorted_tensors:
if tensor.dim() < max_dims:
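
A rough illustration of why the minimum of two dimensions matters (a sketch, not the library's exact code): a batch of scalar features shaped (batch_size,) gets a trailing feature axis before aggregation, so it lines up with the 2-D embedding outputs.

import torch

batch_size = 16
scalar_feature = torch.rand(batch_size)             # shape: (16,)
embedding = torch.rand(batch_size, 8)               # shape: (16, 8)

# Align everything to at least two dimensions, then aggregate.
aligned = scalar_feature.unsqueeze(-1)              # shape: (16, 1)
combined = torch.cat([aligned, embedding], dim=1)   # shape: (16, 9)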
112 changes: 112 additions & 0 deletions tests/unit/torch/blocks/test_dlrm.py
@@ -0,0 +1,112 @@
import math

import pytest
import torch

import merlin.models.torch as mm
from merlin.models.torch.batch import sample_batch
from merlin.models.torch.blocks.dlrm import DLRMInputBlock, DLRMInteraction
from merlin.models.torch.utils import module_utils
from merlin.schema import Tags


class TestDLRMInputBlock:
def test_routes_and_output_shapes(self, testing_data):
schema = testing_data.schema
embedding_dim = 64
block = DLRMInputBlock(schema, embedding_dim, mm.MLPBlock([embedding_dim]))

assert isinstance(block["categorical"], mm.EmbeddingTables)
assert len(block["categorical"]) == len(schema.select_by_tag(Tags.CATEGORICAL))

assert isinstance(block["continuous"][0], mm.SelectKeys)
assert isinstance(block["continuous"][1], mm.MLPBlock)

batch_size = 16
batch = sample_batch(testing_data, batch_size=batch_size)

outputs = module_utils.module_test(block, batch)

for col in schema.select_by_tag(Tags.CATEGORICAL):
assert outputs[col.name].shape == (batch_size, embedding_dim)
assert outputs["continuous"].shape == (batch_size, embedding_dim)


class TestDLRMInteraction:
@pytest.mark.parametrize(
"batch_size,num_features,dim",
[(16, 3, 3), (32, 5, 8), (64, 5, 4)],
)
def test_output_shape(self, batch_size, num_features, dim):
module = DLRMInteraction()
inputs = torch.rand((batch_size, num_features, dim))
outputs = module_utils.module_test(module, inputs)

assert outputs.shape == (batch_size, num_features - 1 + math.comb(num_features - 1, 2))


class TestDLRMBlock:
@pytest.fixture(autouse=True)
def setup_method(self, testing_data):
self.schema = testing_data.schema
self.batch_size = 16
self.batch = sample_batch(testing_data, batch_size=self.batch_size)

def test_dlrm_output_shape(self):
embedding_dim = 64
block = mm.DLRMBlock(
self.schema,
dim=embedding_dim,
bottom_block=mm.MLPBlock([embedding_dim]),
)

outputs = module_utils.module_test(block, self.batch)

num_features = len(self.schema.select_by_tag(Tags.CATEGORICAL)) + 1
dot_product_dim = (num_features - 1) * num_features // 2
assert list(outputs.shape) == [self.batch_size, dot_product_dim + embedding_dim]

def test_dlrm_with_top_block(self):
embedding_dim = 32
top_block_dim = 8
block = mm.DLRMBlock(
self.schema,
dim=embedding_dim,
bottom_block=mm.MLPBlock([embedding_dim]),
top_block=mm.MLPBlock([top_block_dim]),
)

outputs = module_utils.module_test(block, self.batch)

assert list(outputs.shape) == [self.batch_size, top_block_dim]

def test_dlrm_block_no_categorical_features(self):
schema = self.schema.remove_by_tag(Tags.CATEGORICAL)
embedding_dim = 32

with pytest.raises(ValueError, match="must have a categorical input"):
_ = mm.DLRMBlock(
schema,
dim=embedding_dim,
bottom_block=mm.MLPBlock([embedding_dim]),
)

def test_dlrm_block_no_continuous_features(self, testing_data):
schema = testing_data.schema.remove_by_tag(Tags.CONTINUOUS)
testing_data.schema = schema

embedding_dim = 32
block = mm.DLRMBlock(
schema,
dim=embedding_dim,
bottom_block=mm.MLPBlock([embedding_dim]),
)

batch_size = 16
batch = sample_batch(testing_data, batch_size=batch_size)

outputs = module_utils.module_test(block, batch)

num_features = len(schema.select_by_tag(Tags.CATEGORICAL))
dot_product_dim = (num_features - 1) * num_features // 2
assert list(outputs.shape) == [batch_size, dot_product_dim]
6 changes: 3 additions & 3 deletions tests/unit/torch/models/test_base.py
@@ -133,11 +133,11 @@ def test_training_step_with_dataloader(self):
mm.BinaryOutput(ColumnSchema("target")),
)

feature = [[1.0, 2.0], [3.0, 4.0]]
target = [[0.0], [1.0]]
feature = [2.0, 3.0]
target = [0.0, 1.0]
dataset = Dataset(pd.DataFrame({"feature": feature, "target": target}))
Comment on lines -136 to 138
This test case was incorrect because feature = [[1.0, 2.0], [3.0, 4.0]] will produce a list column.
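
For context, a small hypothetical snippet (not part of the diff) showing the difference:

import pandas as pd

# Nested lists produce a single list-valued column: each cell holds a list.
pd.DataFrame({"feature": [[1.0, 2.0], [3.0, 4.0]]})["feature"].iloc[0]  # [1.0, 2.0]

# A flat list gives one scalar per row, which is what the corrected test uses.
pd.DataFrame({"feature": [2.0, 3.0]})["feature"].iloc[0]  # 2.0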


with Loader(dataset, batch_size=1) as loader:
with Loader(dataset, batch_size=2) as loader:
model.initialize(loader)
batch = loader.peek()

6 changes: 4 additions & 2 deletions tests/unit/torch/outputs/test_tabular.py
@@ -29,6 +29,8 @@ def test_exceptions(self):
with pytest.raises(ValueError, match="not found"):
mm.TabularOutputBlock(self.schema, init="not_found")

def test_no_route_for_non_existent_tag(self):
outputs = mm.TabularOutputBlock(self.schema)
with pytest.raises(ValueError):
outputs.add_route(Tags.CATEGORICAL)
outputs.add_route(Tags.CATEGORICAL)
Comment on lines -33 to +34
This was changed to not raise an error because it is possible that there are no continuous features in the DLRM block. Alternatively, we could have a try/except in the DLRM block, but this felt more natural to me.


assert not outputs