From 3fd6980c45a9067f32e905e9a18c64e6544564cb Mon Sep 17 00:00:00 2001
From: Marc Romeyn
Date: Wed, 28 Jun 2023 12:19:34 +0200
Subject: [PATCH 1/7] First pass over DLRM-related blocks

---
 merlin/models/torch/blocks/dlrm.py   | 135 +++++++++++++++++++++++++++
 tests/unit/torch/blocks/test_dlrm.py |  11 +++
 2 files changed, 146 insertions(+)
 create mode 100644 merlin/models/torch/blocks/dlrm.py
 create mode 100644 tests/unit/torch/blocks/test_dlrm.py

diff --git a/merlin/models/torch/blocks/dlrm.py b/merlin/models/torch/blocks/dlrm.py
new file mode 100644
index 0000000000..3239147f81
--- /dev/null
+++ b/merlin/models/torch/blocks/dlrm.py
@@ -0,0 +1,135 @@
+from typing import Dict, Optional
+
+import torch
+from torch import nn
+
+from merlin.models.torch.block import Block
+from merlin.models.torch.inputs.embedding import EmbeddingTables
+from merlin.models.torch.inputs.tabular import TabularInputBlock
+from merlin.models.torch.link import Link
+from merlin.models.torch.transforms.agg import MaybeAgg, Stack
+from merlin.models.utils.doc_utils import docstring_parameter
+from merlin.schema import Schema, Tags
+
+_DLRM_REF = """
+    References
+    ----------
+    .. [1] Naumov, Maxim, et al. "Deep learning recommendation model for
+       personalization and recommendation systems." arXiv preprint arXiv:1906.00091 (2019).
+"""
+
+
+@docstring_parameter(dlrm_reference=_DLRM_REF)
+class DLRMInputBlock(TabularInputBlock):
+    """ "Input-block for DLRM model.
+
+    Parameters
+    ----------
+    schema : Schema, optional
+        The schema to use for selection. Default is None.
+    dim : int
+        The dimensionality of the output vectors.
+    bottom_block : Block
+        Block to pass the continuous features to.
+        Note that the output dimensionality of this block must be equal to ``dim``.
+
+    {dlrm_reference}
+
+    Raises
+    ------
+    ValueError
+        If no categorical input is provided in the schema.
+
+    """
+
+    def __init__(self, schema: Schema, dim: int, bottom_block: Block):
+        super().__init__(schema)
+        self.add_route(Tags.CATEGORICAL, EmbeddingTables(dim, seq_combiner="mean"))
+        self.add_route(Tags.CONTINUOUS, bottom_block)
+
+        if "categorical" not in self:
+            raise ValueError("DLRMInputBlock must have a categorical input")
+
+
+@docstring_parameter(dlrm_reference=_DLRM_REF)
+class DLRMInteraction(nn.Module):
+    """
+    This class defines the forward interaction operation as proposed
+    in the DLRM
+    `paper <https://arxiv.org/pdf/1906.00091.pdf>`_ [1]_.
+
+    This forward operation performs elementwise multiplication
+    of the embeddings followed by a reduction sum.
+
+    {dlrm_reference}
+
+    """
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        if not hasattr(self, "triu_indices"):
+            self.register_buffer(
+                "triu_indices", torch.triu_indices(inputs.shape[1], inputs.shape[1], offset=1)
+            )
+
+        interactions = torch.bmm(inputs, torch.transpose(inputs, 1, 2))
+        interactions_flat = interactions[:, self.triu_indices[0], self.triu_indices[1]]
+
+        return interactions_flat
+
+
+class ShortcutConcatContinuous(Link):
+    """
+    A shortcut connection that concatenates
+    continuous input features and intermediate outputs.
+
+    When there is no continuous input, the intermediate output is returned.
+    """
+
+    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
+        intermediate_output = self.output(inputs)
+
+        if "continuous" in inputs:
+            return torch.cat((inputs["continuous"], intermediate_output), dim=1)
+
+        return intermediate_output
+
+
+@docstring_parameter(dlrm_reference=_DLRM_REF)
+class DLRMBlock(Block):
+    """Builds the DLRM architecture, as proposed in the following
+    `paper <https://arxiv.org/pdf/1906.00091.pdf>`_ [1]_.
+
+    Parameters
+    ----------
+    schema : Schema, optional
+        The schema to use for selection. Default is None.
+    dim : int
+        The dimensionality of the output vectors.
+    bottom_block : Block
+        Block to pass the continuous features to.
+        Note that the output dimensionality of this block must be equal to ``dim``.
+    top_block : Block, optional
+        An optional upper-level block of the model.
+    interaction : nn.Module, default=DLRMInteraction()
+        Interaction module for DLRM.
+
+    {dlrm_reference}
+
+    Raises
+    ------
+    ValueError
+        If no categorical input is provided in the schema.
+    """
+
+    def __init__(
+        self,
+        schema: Schema,
+        dim: int,
+        bottom_block: Block,
+        top_block: Optional[Block] = None,
+        interaction: nn.Module = DLRMInteraction(),
+    ):
+        super().__init__(DLRMInputBlock(schema, dim, bottom_block))
+        self.append(Block(MaybeAgg(Stack()), interaction), link=ShortcutConcatContinuous())
+        if top_block:
+            self.append(top_block)
diff --git a/tests/unit/torch/blocks/test_dlrm.py b/tests/unit/torch/blocks/test_dlrm.py
new file mode 100644
index 0000000000..f1a0d0035e
--- /dev/null
+++ b/tests/unit/torch/blocks/test_dlrm.py
@@ -0,0 +1,11 @@
+import merlin.models.torch as mm
+from merlin.models.torch.blocks.dlrm import (
+    DLRMBlock,
+    DLRMInputBlock,
+    DLRMInteraction,
+    ShortcutConcatContinuous,
+)
+
+
+class TestDLRMInputBlock:
+    ...
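
Note (illustrative only, not part of any patch in this series): a minimal
standalone sketch of the pairwise interaction that DLRMInteraction.forward
computes above. For n stacked embeddings of size d per example, bmm yields
an n x n matrix of dot products, and keeping the strict upper triangle
leaves n * (n - 1) / 2 interaction terms; the sizes below are made up.

    import torch

    batch_size, n, d = 16, 5, 8
    x = torch.rand(batch_size, n, d)

    # (batch, n, n): dot product between every pair of embedding vectors
    interactions = torch.bmm(x, torch.transpose(x, 1, 2))

    # keep each unordered pair exactly once (strict upper triangle)
    triu = torch.triu_indices(n, n, offset=1)
    flat = interactions[:, triu[0], triu[1]]

    assert flat.shape == (batch_size, n * (n - 1) // 2)  # (16, 10)
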
From 73a2b62aeefe5836bc9d597e02628f5a3d3baa6c Mon Sep 17 00:00:00 2001
From: edknv
Date: Wed, 28 Jun 2023 22:10:37 -0700
Subject: [PATCH 2/7] add a test for dlrm block and fixes to make test pass

---
 merlin/models/torch/__init__.py       | 10 ++++++++
 merlin/models/torch/blocks/dlrm.py    |  2 +-
 merlin/models/torch/transforms/agg.py |  4 ++++
 tests/unit/torch/blocks/test_dlrm.py  | 33 ++++++++++++++++++++++-----
 4 files changed, 42 insertions(+), 7 deletions(-)

diff --git a/merlin/models/torch/__init__.py b/merlin/models/torch/__init__.py
index eef6f9ece6..35c9ab8099 100644
--- a/merlin/models/torch/__init__.py
+++ b/merlin/models/torch/__init__.py
@@ -17,6 +17,12 @@
 from merlin.models.torch import schema
 from merlin.models.torch.batch import Batch, Sequence
 from merlin.models.torch.block import Block, ParallelBlock
+from merlin.models.torch.blocks.dlrm import (
+    DLRMBlock,
+    DLRMInputBlock,
+    DLRMInteraction,
+    ShortcutConcatContinuous,
+)
 from merlin.models.torch.blocks.mlp import MLPBlock
 from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables
 from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys
@@ -51,4 +57,8 @@
     "Concat",
     "Stack",
     "schema",
+    "DLRMBlock",
+    "DLRMInputBlock",
+    "DLRMInteraction",
+    "ShortcutConcatContinuous",
 ]
diff --git a/merlin/models/torch/blocks/dlrm.py b/merlin/models/torch/blocks/dlrm.py
index 3239147f81..37dd80c2ee 100644
--- a/merlin/models/torch/blocks/dlrm.py
+++ b/merlin/models/torch/blocks/dlrm.py
@@ -130,6 +130,6 @@ def __init__(
         interaction: nn.Module = DLRMInteraction(),
     ):
         super().__init__(DLRMInputBlock(schema, dim, bottom_block))
-        self.append(Block(MaybeAgg(Stack()), interaction), link=ShortcutConcatContinuous())
+        self.append(Block(MaybeAgg(Stack(dim=1)), interaction), link=ShortcutConcatContinuous())
         if top_block:
             self.append(top_block)
diff --git a/merlin/models/torch/transforms/agg.py b/merlin/models/torch/transforms/agg.py
index f6dcf457d6..4979b1dc46 100644
--- a/merlin/models/torch/transforms/agg.py
+++ b/merlin/models/torch/transforms/agg.py
@@ -104,6 +104,9 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
         if self.align_dims:
             max_dims = max(tensor.dim() for tensor in sorted_tensors)
+            max_dims = max(
+                max_dims, 2
+            )  # assume first dimension is batch size + at least one feature
             _sorted_tensors = []
             for tensor in sorted_tensors:
                 if tensor.dim() < max_dims:
@@ -140,6 +143,7 @@ class Stack(AggModule):
     dim : int
         The dimension along which the tensors will be stacked.
         Default is 0.
+        unsqeeze: bool
 
     Examples
     --------
diff --git a/tests/unit/torch/blocks/test_dlrm.py b/tests/unit/torch/blocks/test_dlrm.py
index f1a0d0035e..6d9a8fd084 100644
--- a/tests/unit/torch/blocks/test_dlrm.py
+++ b/tests/unit/torch/blocks/test_dlrm.py
@@ -1,11 +1,32 @@
 import merlin.models.torch as mm
-from merlin.models.torch.blocks.dlrm import (
-    DLRMBlock,
-    DLRMInputBlock,
-    DLRMInteraction,
-    ShortcutConcatContinuous,
-)
+from merlin.models.torch.batch import sample_batch
+from merlin.models.torch.utils import module_utils
+from merlin.schema import Tags
 
 
 class TestDLRMInputBlock:
     ...
+
+
+class DLRMInteraction:
+    ...
+
+
+class TestDLRMBlock:
+    def test_basic(self, testing_data):
+        schema = testing_data.schema
+        embedding_dim = 64
+        block = mm.DLRMBlock(
+            schema,
+            dim=embedding_dim,
+            bottom_block=mm.MLPBlock([embedding_dim]),
+        )
+        batch_size = 16
+        batch = sample_batch(testing_data, batch_size=batch_size)
+        block.to(device=batch.device())  # TODO: move this
+
+        outputs = module_utils.module_test(block, batch.features)
+
+        num_features = len(schema.select_by_tag(Tags.CATEGORICAL)) + 1
+        dot_product_dim = (num_features - 1) * num_features // 2
+        assert list(outputs.shape) == [batch_size, dot_product_dim + embedding_dim]

From c9d43da72e1fc1d90b6f91dcea9a5a2493148443 Mon Sep 17 00:00:00 2001
From: edknv
Date: Thu, 29 Jun 2023 20:23:11 -0700
Subject: [PATCH 3/7] Add unit tests

---
 merlin/models/torch/__init__.py       |  10 +--
 merlin/models/torch/blocks/dlrm.py    |   5 +-
 merlin/models/torch/router.py         |   3 +
 merlin/models/torch/transforms/agg.py |   1 -
 tests/unit/torch/blocks/test_dlrm.py  | 104 +++++++++++++++++++++++---
 5 files changed, 100 insertions(+), 23 deletions(-)

diff --git a/merlin/models/torch/__init__.py b/merlin/models/torch/__init__.py
index 35c9ab8099..d2326af5e9 100644
--- a/merlin/models/torch/__init__.py
+++ b/merlin/models/torch/__init__.py
@@ -17,12 +17,7 @@
 from merlin.models.torch import schema
 from merlin.models.torch.batch import Batch, Sequence
 from merlin.models.torch.block import Block, ParallelBlock
-from merlin.models.torch.blocks.dlrm import (
-    DLRMBlock,
-    DLRMInputBlock,
-    DLRMInteraction,
-    ShortcutConcatContinuous,
-)
+from merlin.models.torch.blocks.dlrm import DLRMBlock
 from merlin.models.torch.blocks.mlp import MLPBlock
 from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables
 from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys
@@ -58,7 +53,4 @@
     "Stack",
     "schema",
     "DLRMBlock",
-    "DLRMInputBlock",
-    "DLRMInteraction",
-    "ShortcutConcatContinuous",
 ]
diff --git a/merlin/models/torch/blocks/dlrm.py b/merlin/models/torch/blocks/dlrm.py
index 37dd80c2ee..6d4bfa9d07 100644
--- a/merlin/models/torch/blocks/dlrm.py
+++ b/merlin/models/torch/blocks/dlrm.py
@@ -21,7 +21,7 @@
 @docstring_parameter(dlrm_reference=_DLRM_REF)
 class DLRMInputBlock(TabularInputBlock):
-    """ "Input-block for DLRM model.
+    """Input block for DLRM model.
 
     Parameters
     ----------
     schema : Schema, optional
@@ -130,6 +130,9 @@ def __init__(
         interaction: nn.Module = DLRMInteraction(),
     ):
         super().__init__(DLRMInputBlock(schema, dim, bottom_block))
+
+        # link = ShortcutConcatContinuous() if "continuous" in self[0] else None
         self.append(Block(MaybeAgg(Stack(dim=1)), interaction), link=ShortcutConcatContinuous())
+
         if top_block:
             self.append(top_block)
diff --git a/merlin/models/torch/router.py b/merlin/models/torch/router.py
index 29126e0f91..326064c68c 100644
--- a/merlin/models/torch/router.py
+++ b/merlin/models/torch/router.py
@@ -88,6 +88,9 @@ def add_route(
         """
         routing_module = schema.select(self.selectable, selection)
 
+        if not routing_module:
+            return self
+
         if module is not None:
             schema.setup_schema(module, routing_module.schema)
diff --git a/merlin/models/torch/transforms/agg.py b/merlin/models/torch/transforms/agg.py
index 4979b1dc46..552fcf1d32 100644
--- a/merlin/models/torch/transforms/agg.py
+++ b/merlin/models/torch/transforms/agg.py
@@ -143,7 +143,6 @@ class Stack(AggModule):
     dim : int
         The dimension along which the tensors will be stacked.
         Default is 0.
-        unsqeeze: bool
 
     Examples
     --------
diff --git a/tests/unit/torch/blocks/test_dlrm.py b/tests/unit/torch/blocks/test_dlrm.py
index 6d9a8fd084..7f4a712f29 100644
--- a/tests/unit/torch/blocks/test_dlrm.py
+++ b/tests/unit/torch/blocks/test_dlrm.py
@@ -1,32 +1,112 @@
+import math
+
+import pytest
+import torch
+
 import merlin.models.torch as mm
 from merlin.models.torch.batch import sample_batch
+from merlin.models.torch.blocks.dlrm import DLRMInputBlock, DLRMInteraction
 from merlin.models.torch.utils import module_utils
 from merlin.schema import Tags
 
 
 class TestDLRMInputBlock:
-    ...
+    def test_routes_and_output_shapes(self, testing_data):
+        schema = testing_data.schema
+        embedding_dim = 64
+        block = DLRMInputBlock(schema, embedding_dim, mm.MLPBlock([embedding_dim]))
+
+        assert isinstance(block["categorical"], mm.EmbeddingTables)
+        assert len(block["categorical"]) == len(schema.select_by_tag(Tags.CATEGORICAL))
+
+        assert isinstance(block["continuous"][0], mm.SelectKeys)
+        assert isinstance(block["continuous"][1], mm.MLPBlock)
+
+        batch_size = 16
+        batch = sample_batch(testing_data, batch_size=batch_size)
+
+        outputs = module_utils.module_test(block, batch.features)
+
+        for col in schema.select_by_tag(Tags.CATEGORICAL):
+            assert outputs[col.name].shape == (batch_size, embedding_dim)
+        assert outputs["continuous"].shape == (batch_size, embedding_dim)
 
 
-class DLRMInteraction:
-    ...
+class TestDLRMInteraction:
+    @pytest.mark.parametrize(
+        "batch_size,num_features,dim",
+        [(16, 3, 3), (32, 5, 8), (64, 5, 4)],
+    )
+    def test_output_shape(self, batch_size, num_features, dim):
+        module = DLRMInteraction()
+        inputs = torch.rand((batch_size, num_features, dim))
+
+        outputs = module_utils.module_test(module, inputs)
+
+        assert outputs.shape == (batch_size, num_features - 1 + math.comb(num_features - 1, 2))
 
 
 class TestDLRMBlock:
-    def test_basic(self, testing_data):
-        schema = testing_data.schema
+    @pytest.fixture(autouse=True)
+    def setup_method(self, testing_data):
+        self.schema = testing_data.schema
+        self.batch_size = 16
+        self.batch = sample_batch(testing_data, batch_size=self.batch_size)
+
+    def test_dlrm_output_shape(self):
         embedding_dim = 64
         block = mm.DLRMBlock(
-            schema,
+            self.schema,
             dim=embedding_dim,
             bottom_block=mm.MLPBlock([embedding_dim]),
         )
-        batch_size = 16
-        batch = sample_batch(testing_data, batch_size=batch_size)
-        block.to(device=batch.device())  # TODO: move this
 
-        outputs = module_utils.module_test(block, batch.features)
+        outputs = module_utils.module_test(block, self.batch.features)
+
+        num_features = len(self.schema.select_by_tag(Tags.CATEGORICAL)) + 1
+        dot_product_dim = (num_features - 1) * num_features // 2
+        assert list(outputs.shape) == [self.batch_size, dot_product_dim + embedding_dim]
+
+    def test_dlrm_with_top_block(self):
+        embedding_dim = 32
+        top_block_dim = 8
+        block = mm.DLRMBlock(
+            self.schema,
+            dim=embedding_dim,
+            bottom_block=mm.MLPBlock([embedding_dim]),
+            top_block=mm.MLPBlock([top_block_dim]),
+        )
+
+        outputs = module_utils.module_test(block, self.batch.features)
+
+        assert list(outputs.shape) == [self.batch_size, top_block_dim]
+
+    def test_dlrm_block_no_categorical_features(self):
+        schema = self.schema.remove_by_tag(Tags.CATEGORICAL)
+        embedding_dim = 32
+
+        with pytest.raises(ValueError, match="must have a categorical input"):
+            _ = mm.DLRMBlock(
+                schema,
+                dim=embedding_dim,
+                bottom_block=mm.MLPBlock([embedding_dim]),
+            )
+
+    def test_dlrm_block_no_continuous_features(self):
+        embedding_dim = 32
+        block = mm.DLRMBlock(
+            self.schema.remove_by_tag(Tags.CONTINUOUS),
+            dim=embedding_dim,
+            bottom_block=mm.MLPBlock([embedding_dim]),
+        )
+        continuous_features = [col.name for col in self.schema.select_by_tag(Tags.CONTINUOUS)]
+        inputs = {
+            name: self.batch.features[name]
+            for name in sorted(self.batch.features)
+            if name not in continuous_features
+        }
+
+        outputs = module_utils.module_test(block, inputs)
+
+        num_features = len(self.schema.select_by_tag(Tags.CATEGORICAL))
+        dot_product_dim = (num_features - 1) * num_features // 2
+        assert list(outputs.shape) == [self.batch_size, dot_product_dim]

From 790835be18775a79ad3416a68e0639b743f43879 Mon Sep 17 00:00:00 2001
From: Marc Romeyn
Date: Fri, 30 Jun 2023 09:42:09 +0200
Subject: [PATCH 4/7] Initialize default metrics/loss inside ModelOutput instead

---
 merlin/models/torch/outputs/classification.py   | 16 +++++++---------
 merlin/models/torch/outputs/regression.py       | 11 +++++++----
 tests/unit/torch/outputs/test_classification.py |  4 ++--
 tests/unit/torch/outputs/test_regression.py     |  2 +-
 4 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/merlin/models/torch/outputs/classification.py b/merlin/models/torch/outputs/classification.py
index e201953fd0..2ca36143f7 100644
--- a/merlin/models/torch/outputs/classification.py
+++ b/merlin/models/torch/outputs/classification.py
@@ -36,24 +36,22 @@ class BinaryOutput(ModelOutput):
         The metrics used for evaluation. Default includes Accuracy, AUROC, Precision, and Recall.
     """
 
+    DEFAULT_LOSS_CLS = nn.BCEWithLogitsLoss
+    DEFAULT_METRICS_CLS = (Accuracy, AUROC, Precision, Recall)
+
     def __init__(
         self,
         schema: Optional[ColumnSchema] = None,
-        loss: nn.Module = nn.BCEWithLogitsLoss(),
-        metrics: Sequence[Metric] = (
-            Accuracy(task="binary"),
-            AUROC(task="binary"),
-            Precision(task="binary"),
-            Recall(task="binary"),
-        ),
+        loss: Optional[nn.Module] = None,
+        metrics: Sequence[Metric] = (),
     ):
         """Initializes a BinaryOutput object."""
         super().__init__(
             nn.LazyLinear(1),
             nn.Sigmoid(),
             schema=schema,
-            loss=loss,
-            metrics=metrics,
+            loss=loss or self.DEFAULT_LOSS_CLS(),
+            metrics=metrics or [m(task="binary") for m in self.DEFAULT_METRICS_CLS],
         )
 
     def setup_schema(self, target: Optional[Union[ColumnSchema, Schema]]):
diff --git a/merlin/models/torch/outputs/regression.py b/merlin/models/torch/outputs/regression.py
index 0f9f9ad318..e3b2f97b09 100644
--- a/merlin/models/torch/outputs/regression.py
+++ b/merlin/models/torch/outputs/regression.py
@@ -36,18 +36,21 @@ class RegressionOutput(ModelOutput):
         The metrics used for evaluation. Default is MeanSquaredError.
""" + DEFAULT_LOSS_CLS = nn.MSELoss + DEFAULT_METRICS_CLS = (MeanSquaredError,) + def __init__( self, schema: Optional[ColumnSchema] = None, - loss: nn.Module = nn.MSELoss(), - metrics: Sequence[Metric] = (MeanSquaredError(),), + loss: Optional[nn.Module] = None, + metrics: Sequence[Metric] = (), ): """Initializes a RegressionOutput object.""" super().__init__( nn.LazyLinear(1), schema=schema, - loss=loss, - metrics=metrics, + loss=loss or self.DEFAULT_LOSS_CLS(), + metrics=metrics or [m() for m in self.DEFAULT_METRICS_CLS], ) def setup_schema(self, target: Optional[Union[ColumnSchema, Schema]]): diff --git a/tests/unit/torch/outputs/test_classification.py b/tests/unit/torch/outputs/test_classification.py index ea6643c740..755d465350 100644 --- a/tests/unit/torch/outputs/test_classification.py +++ b/tests/unit/torch/outputs/test_classification.py @@ -31,12 +31,12 @@ def test_init(self): assert isinstance(binary_output, mm.BinaryOutput) assert isinstance(binary_output.loss, nn.BCEWithLogitsLoss) - assert binary_output.metrics == ( + assert binary_output.metrics == [ Accuracy(task="binary"), AUROC(task="binary"), Precision(task="binary"), Recall(task="binary"), - ) + ] assert binary_output.output_schema == Schema() def test_identity(self): diff --git a/tests/unit/torch/outputs/test_regression.py b/tests/unit/torch/outputs/test_regression.py index 17541c5081..f8537bca51 100644 --- a/tests/unit/torch/outputs/test_regression.py +++ b/tests/unit/torch/outputs/test_regression.py @@ -30,7 +30,7 @@ def test_init(self): assert isinstance(reg_output, mm.RegressionOutput) assert isinstance(reg_output.loss, nn.MSELoss) - assert reg_output.metrics == (MeanSquaredError,) + assert reg_output.metrics == [MeanSquaredError()] assert reg_output.output_schema == Schema() def test_identity(self): From 03a60b5db9cc9f12b33bd6084650a272026d1d03 Mon Sep 17 00:00:00 2001 From: edknv <109497216+edknv@users.noreply.github.com> Date: Fri, 30 Jun 2023 01:18:52 -0700 Subject: [PATCH 5/7] Update merlin/models/torch/blocks/dlrm.py Co-authored-by: Radek Osmulski --- merlin/models/torch/blocks/dlrm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/merlin/models/torch/blocks/dlrm.py b/merlin/models/torch/blocks/dlrm.py index 6d4bfa9d07..90cd860069 100644 --- a/merlin/models/torch/blocks/dlrm.py +++ b/merlin/models/torch/blocks/dlrm.py @@ -59,7 +59,7 @@ class DLRMInteraction(nn.Module): `paper https://arxiv.org/pdf/1906.00091.pdf`_ [1]_. This forward operation performs elementwise multiplication - of the embeddings followed by a reduction sum. + followed by a reduction sum (equivalent to a dot product) of all embedding pairs. {dlrm_reference} From 26ad78ae5024449b6825712ca9ad4c67b3608c80 Mon Sep 17 00:00:00 2001 From: edknv Date: Fri, 30 Jun 2023 20:12:13 -0700 Subject: [PATCH 6/7] fixes and changes for failing tests --- merlin/models/torch/blocks/dlrm.py | 11 +++++++---- tests/unit/torch/models/test_base.py | 6 +++--- tests/unit/torch/outputs/test_tabular.py | 6 ++++-- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/merlin/models/torch/blocks/dlrm.py b/merlin/models/torch/blocks/dlrm.py index 90cd860069..a24e4d1f71 100644 --- a/merlin/models/torch/blocks/dlrm.py +++ b/merlin/models/torch/blocks/dlrm.py @@ -110,8 +110,9 @@ class DLRMBlock(Block): Note that, the output dimensionality of this block must be equal to ``dim``. top_block : Block, optional An optional upper-level block of the model. 
-    interaction : nn.Module, default=DLRMInteraction()
+    interaction : nn.Module, optional
         Interaction module for DLRM.
+        If not provided, DLRMInteraction will be used by default.
 
     {dlrm_reference}
@@ -127,12 +128,14 @@ def __init__(
         schema: Schema,
         dim: int,
         bottom_block: Block,
         top_block: Optional[Block] = None,
-        interaction: nn.Module = DLRMInteraction(),
+        interaction: Optional[nn.Module] = None,
     ):
         super().__init__(DLRMInputBlock(schema, dim, bottom_block))
 
-        # link = ShortcutConcatContinuous() if "continuous" in self[0] else None
-        self.append(Block(MaybeAgg(Stack(dim=1)), interaction), link=ShortcutConcatContinuous())
+        self.append(
+            Block(MaybeAgg(Stack(dim=1)), interaction or DLRMInteraction()),
+            link=ShortcutConcatContinuous(),
+        )
 
         if top_block:
             self.append(top_block)
diff --git a/tests/unit/torch/models/test_base.py b/tests/unit/torch/models/test_base.py
index cedd3e6ff6..ab329b8ca1 100644
--- a/tests/unit/torch/models/test_base.py
+++ b/tests/unit/torch/models/test_base.py
@@ -133,11 +133,11 @@ def test_training_step_with_dataloader(self):
             mm.BinaryOutput(ColumnSchema("target")),
         )
 
-        feature = [[1.0, 2.0], [3.0, 4.0]]
-        target = [[0.0], [1.0]]
+        feature = [2.0, 3.0]
+        target = [0.0, 1.0]
         dataset = Dataset(pd.DataFrame({"feature": feature, "target": target}))
 
-        with Loader(dataset, batch_size=1) as loader:
+        with Loader(dataset, batch_size=2) as loader:
             model.initialize(loader)
             batch = loader.peek()
diff --git a/tests/unit/torch/outputs/test_tabular.py b/tests/unit/torch/outputs/test_tabular.py
index b3dd2abe4b..22ae132735 100644
--- a/tests/unit/torch/outputs/test_tabular.py
+++ b/tests/unit/torch/outputs/test_tabular.py
@@ -29,6 +29,8 @@ def test_exceptions(self):
         with pytest.raises(ValueError, match="not found"):
             mm.TabularOutputBlock(self.schema, init="not_found")
 
+    def test_no_route_for_non_existent_tag(self):
         outputs = mm.TabularOutputBlock(self.schema)
-        with pytest.raises(ValueError):
-            outputs.add_route(Tags.CATEGORICAL)
+        outputs.add_route(Tags.CATEGORICAL)
+
+        assert not outputs

From b5c0e3f362c72c25501aebca8ae77814f818c561 Mon Sep 17 00:00:00 2001
From: edknv
Date: Fri, 30 Jun 2023 20:33:32 -0700
Subject: [PATCH 7/7] pass batch to model test

---
 tests/unit/torch/blocks/test_dlrm.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/tests/unit/torch/blocks/test_dlrm.py b/tests/unit/torch/blocks/test_dlrm.py
index 7f4a712f29..21d65e8561 100644
--- a/tests/unit/torch/blocks/test_dlrm.py
+++ b/tests/unit/torch/blocks/test_dlrm.py
@@ -25,7 +25,7 @@ def test_routes_and_output_shapes(self, testing_data):
         batch_size = 16
         batch = sample_batch(testing_data, batch_size=batch_size)
 
-        outputs = module_utils.module_test(block, batch.features)
+        outputs = module_utils.module_test(block, batch)
 
         for col in schema.select_by_tag(Tags.CATEGORICAL):
             assert outputs[col.name].shape == (batch_size, embedding_dim)
         assert outputs["continuous"].shape == (batch_size, embedding_dim)
@@ -60,7 +60,7 @@ def test_dlrm_output_shape(self):
             bottom_block=mm.MLPBlock([embedding_dim]),
         )
 
-        outputs = module_utils.module_test(block, self.batch.features)
+        outputs = module_utils.module_test(block, self.batch)
 
         num_features = len(self.schema.select_by_tag(Tags.CATEGORICAL)) + 1
         dot_product_dim = (num_features - 1) * num_features // 2
@@ -76,7 +76,7 @@ def test_dlrm_with_top_block(self):
             top_block=mm.MLPBlock([top_block_dim]),
         )
 
-        outputs = module_utils.module_test(block, self.batch.features)
+        outputs = module_utils.module_test(block, self.batch)
 
         assert list(outputs.shape) == [self.batch_size, top_block_dim]
@@ -91,22 +91,22 @@ def test_dlrm_block_no_categorical_features(self):
                 bottom_block=mm.MLPBlock([embedding_dim]),
             )
 
-    def test_dlrm_block_no_continuous_features(self):
+    def test_dlrm_block_no_continuous_features(self, testing_data):
+        schema = testing_data.schema.remove_by_tag(Tags.CONTINUOUS)
+        testing_data.schema = schema
+
         embedding_dim = 32
         block = mm.DLRMBlock(
-            self.schema.remove_by_tag(Tags.CONTINUOUS),
+            schema,
             dim=embedding_dim,
             bottom_block=mm.MLPBlock([embedding_dim]),
         )
-        continuous_features = [col.name for col in self.schema.select_by_tag(Tags.CONTINUOUS)]
-        inputs = {
-            name: self.batch.features[name]
-            for name in sorted(self.batch.features)
-            if name not in continuous_features
-        }
+
+        batch_size = 16
+        batch = sample_batch(testing_data, batch_size=batch_size)
 
-        outputs = module_utils.module_test(block, inputs)
+        outputs = module_utils.module_test(block, batch)
 
-        num_features = len(self.schema.select_by_tag(Tags.CATEGORICAL))
+        num_features = len(schema.select_by_tag(Tags.CATEGORICAL))
         dot_product_dim = (num_features - 1) * num_features // 2
-        assert list(outputs.shape) == [self.batch_size, dot_product_dim]
+        assert list(outputs.shape) == [batch_size, dot_product_dim]
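
Note (illustrative only, not part of any patch in this series): the shape
arithmetic behind the assertions in TestDLRMBlock, assuming a hypothetical
schema with three categorical columns plus at least one continuous column.
The bottom block's aggregated continuous vector joins the categorical
embeddings as one extra feature, and ShortcutConcatContinuous then appends
that vector (size embedding_dim) to the flattened interactions.

    embedding_dim = 64
    num_categorical = 3                   # hypothetical schema
    num_features = num_categorical + 1    # + aggregated continuous vector

    # strict upper triangle of the (num_features x num_features) dot products
    dot_product_dim = (num_features - 1) * num_features // 2   # 6

    output_dim = dot_product_dim + embedding_dim               # 70

Without continuous features, both the "+ 1" feature and the concatenated
embedding_dim term drop out, which is exactly what
test_dlrm_block_no_continuous_features asserts.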