diff --git a/merlin/models/torch/__init__.py b/merlin/models/torch/__init__.py index 293be536e2..153b88221e 100644 --- a/merlin/models/torch/__init__.py +++ b/merlin/models/torch/__init__.py @@ -23,7 +23,7 @@ from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys from merlin.models.torch.inputs.tabular import TabularInputBlock from merlin.models.torch.models.base import Model -from merlin.models.torch.models.ranking import DLRMModel +from merlin.models.torch.models.ranking import DCNModel, DLRMModel from merlin.models.torch.outputs.base import ModelOutput from merlin.models.torch.outputs.classification import ( BinaryOutput, @@ -75,4 +75,5 @@ "target_schema", "DLRMBlock", "DLRMModel", + "DCNModel", ] diff --git a/merlin/models/torch/blocks/cross.py b/merlin/models/torch/blocks/cross.py index ebfa215bcf..c395fc8440 100644 --- a/merlin/models/torch/blocks/cross.py +++ b/merlin/models/torch/blocks/cross.py @@ -5,6 +5,7 @@ from torch import nn from torch.nn.modules.lazy import LazyModuleMixin +from merlin.models.torch.batch import Batch from merlin.models.torch.block import Block from merlin.models.torch.transforms.agg import Concat from merlin.models.utils.doc_utils import docstring_parameter @@ -127,7 +128,9 @@ def with_low_rank(cls, depth: int, low_rank: nn.Module) -> "CrossBlock": return cls(*(Block(deepcopy(low_rank), *block) for block in cls.with_depth(depth))) - def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor: + def forward( + self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None + ) -> torch.Tensor: """Forward-pass of the cross-block. Parameters diff --git a/merlin/models/torch/models/ranking.py b/merlin/models/torch/models/ranking.py index 292abebbd8..d265f2ef7c 100644 --- a/merlin/models/torch/models/ranking.py +++ b/merlin/models/torch/models/ranking.py @@ -2,13 +2,19 @@ from torch import nn -from merlin.models.torch.block import Block -from merlin.models.torch.blocks.dlrm import DLRMBlock +from merlin.models.torch.block import Block, ParallelBlock +from merlin.models.torch.blocks.cross import _DCNV2_REF, CrossBlock +from merlin.models.torch.blocks.dlrm import _DLRM_REF, DLRMBlock +from merlin.models.torch.blocks.mlp import MLPBlock +from merlin.models.torch.inputs.tabular import TabularInputBlock from merlin.models.torch.models.base import Model from merlin.models.torch.outputs.tabular import TabularOutputBlock +from merlin.models.torch.transforms.agg import Concat, MaybeAgg +from merlin.models.utils.doc_utils import docstring_parameter from merlin.schema import Schema +@docstring_parameter(dlrm_reference=_DLRM_REF) class DLRMModel(Model): """ The Deep Learning Recommendation Model (DLRM) as proposed in Naumov, et al. [1] @@ -42,15 +48,13 @@ class DLRMModel(Model): ... schema, ... dim=64, ... bottom_block=mm.MLPBlock([256, 64]), - ... output_block=BinaryOutput(ColumnSchema("target"))) + ... output_block=mm.BinaryOutput(ColumnSchema("target")), + ... ) >>> trainer = pl.Trainer() >>> model.initialize(dataloader) >>> trainer.fit(model, dataloader) - References - ---------- - [1] Naumov, Maxim, et al. "Deep learning recommendation model for - personalization and recommendation systems." arXiv preprint arXiv:1906.00091 (2019). + {dlrm_reference} """ def __init__( @@ -74,3 +78,74 @@ def __init__( ) super().__init__(dlrm_body, output_block) + + +@docstring_parameter(dcn_reference=_DCNV2_REF) +class DCNModel(Model): + """ + The Deep & Cross Network (DCN) architecture as proposed in Wang, et al. 
[1]
+
+    Parameters
+    ----------
+    schema : Schema
+        The schema to use for selection.
+    depth : int, optional
+        Number of cross-layers to be stacked, by default 1
+    deep_block : Block, optional
+        The `Block` to use as the deep part of the model (typically an `MLPBlock`).
+    stacked : bool, optional
+        Whether to use the stacked version of the model or the parallel version, by default True.
+    input_block : Block, optional
+        The `Block` to use as the input layer. If None, a default `TabularInputBlock` object
+        is instantiated that creates the embedding tables for the categorical features
+        based on the schema. The embedding dimensions are inferred from the features'
+        cardinalities. For a custom representation of the input data you can instantiate
+        and provide your own `TabularInputBlock` instance.
+
+    Returns
+    -------
+    Model
+        An instance of the Model class representing the fully formed DCN.
+
+    Example usage
+    -------------
+    >>> model = mm.DCNModel(
+    ...     schema,
+    ...     depth=2,
+    ...     deep_block=mm.MLPBlock([256, 64]),
+    ...     output_block=mm.BinaryOutput(ColumnSchema("target")),
+    ... )
+    >>> trainer = pl.Trainer()
+    >>> model.initialize(dataloader)
+    >>> trainer.fit(model, dataloader)
+
+    {dcn_reference}
+    """
+
+    def __init__(
+        self,
+        schema: Schema,
+        depth: int = 1,
+        deep_block: Optional[Block] = None,
+        stacked: bool = True,
+        input_block: Optional[Block] = None,
+        output_block: Optional[Block] = None,
+    ) -> None:
+        if input_block is None:
+            input_block = TabularInputBlock(schema, init="defaults")
+
+        if output_block is None:
+            output_block = TabularOutputBlock(schema, init="defaults")
+
+        if deep_block is None:
+            deep_block = MLPBlock([512, 256])
+
+        if stacked:
+            cross_network = Block(CrossBlock.with_depth(depth), deep_block)
+        else:
+            cross_network = Block(
+                ParallelBlock({"cross": CrossBlock.with_depth(depth), "deep": deep_block}),
+                MaybeAgg(Concat()),
+            )
+
+        super().__init__(input_block, *cross_network, output_block)
diff --git a/tests/unit/torch/models/test_ranking.py b/tests/unit/torch/models/test_ranking.py
index 0fb463e0ef..1a9e5678e2 100644
--- a/tests/unit/torch/models/test_ranking.py
+++ b/tests/unit/torch/models/test_ranking.py
@@ -36,3 +36,34 @@ def test_train_dlrm_with_lightning_loader(
         batch = sample_batch(music_streaming_data, batch_size)
 
         _ = module_utils.module_test(model, batch)
+
+
+class TestDCNModel:
+    @pytest.mark.parametrize("depth", [1, 2])
+    @pytest.mark.parametrize("stacked", [True, False])
+    @pytest.mark.parametrize("deep_block", [None, mm.MLPBlock([4, 2])])
+    def test_train_dcn_with_lightning_trainer(
+        self,
+        music_streaming_data,
+        depth,
+        stacked,
+        deep_block,
+        batch_size=16,
+    ):
+        schema = music_streaming_data.schema.select_by_name(
+            ["item_id", "user_id", "user_age", "item_genres", "click"]
+        )
+        music_streaming_data.schema = schema
+
+        model = mm.DCNModel(schema, depth=depth, deep_block=deep_block, stacked=stacked)
+
+        trainer = pl.Trainer(max_epochs=1, devices=1)
+
+        with Loader(music_streaming_data, batch_size=batch_size) as train_loader:
+            model.initialize(train_loader)
+            trainer.fit(model, train_loader)
+
+        assert trainer.logged_metrics["train_loss"] > 0.0
+
+        batch = sample_batch(music_streaming_data, batch_size)
+        _ = module_utils.module_test(model, batch)
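
For quick reference, the stacked and parallel DCN variants added in this diff can be exercised roughly as in the docstring example and the new test. The snippet below is a minimal sketch, not part of the patch: it assumes `merlin.models.torch` is importable as `mm`, a hypothetical Merlin `Dataset` named `train_data` whose schema includes a binary target, and that `Loader` comes from the merlin-dataloader package alongside PyTorch Lightning.

    # Minimal usage sketch; `train_data` is an assumed Merlin Dataset with a binary target.
    import pytorch_lightning as pl

    import merlin.models.torch as mm
    from merlin.dataloader.torch import Loader

    schema = train_data.schema

    # Stacked variant (default): the cross network feeds into the deep MLP.
    stacked_model = mm.DCNModel(schema, depth=2, deep_block=mm.MLPBlock([256, 64]))

    # Parallel variant: cross network and deep MLP run side by side,
    # and their outputs are concatenated before the output block.
    parallel_model = mm.DCNModel(
        schema,
        depth=2,
        deep_block=mm.MLPBlock([256, 64]),
        stacked=False,
    )

    trainer = pl.Trainer(max_epochs=1)
    with Loader(train_data, batch_size=16) as train_loader:
        stacked_model.initialize(train_loader)  # build lazy modules from a sample batch
        trainer.fit(stacked_model, train_loader)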