Adding sequence-defaults in TabularInputBlock (#1200)
* Adding sequence-defaults in TabularInputBlock

* First pass over stack_context

* Fixing failing tests
marcromeyn committed Jul 12, 2023
1 parent 8b1ff94 commit 28b690c
Showing 14 changed files with 225 additions and 43 deletions.
3 changes: 2 additions & 1 deletion merlin/models/torch/__init__.py
@@ -31,7 +31,7 @@
from merlin.models.torch.blocks.mlp import MLPBlock
from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables
from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys
-from merlin.models.torch.inputs.tabular import TabularInputBlock
+from merlin.models.torch.inputs.tabular import TabularInputBlock, stack_context
from merlin.models.torch.models.base import Model, MultiLoader
from merlin.models.torch.models.ranking import DCNModel, DLRMModel
from merlin.models.torch.outputs.base import ModelOutput
@@ -101,4 +101,5 @@
"EncoderBlock",
"DaskEncoder",
"DaskPredictor",
"stack_context",
]
8 changes: 7 additions & 1 deletion merlin/models/torch/batch.py
@@ -326,8 +326,14 @@ def flatten_as_dict(self, inputs: Optional["Batch"]) -> Dict[str, torch.Tensor]:

if torch.jit.isinstance(inputs, Batch) and inputs is not self:
_input_dict: Dict[str, torch.Tensor] = inputs._flatten()
-            for key in _input_dict:
+            for key, val in _input_dict.items():
                 flat_dict["inputs." + key] = dummy_tensor
+                if (
+                    not key.endswith("__values")
+                    and not key.endswith("__offsets")
+                    and key not in flat_dict
+                ):
+                    flat_dict[key] = val

return flat_dict

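For intuition, the new loop copies each real tensor into the flattened dict under its bare key, while skipping the __values/__offsets parts of ragged tensors and anything already present. A standalone sketch of that filtering rule, with plain dicts standing in for the Batch internals:

import torch

_input_dict = {
    "user_age": torch.tensor([1.0]),
    "item_ids__values": torch.tensor([1, 2]),   # ragged parts are skipped
    "item_ids__offsets": torch.tensor([0, 2]),
}
flat_dict = {}

for key, val in _input_dict.items():
    if (
        not key.endswith("__values")
        and not key.endswith("__offsets")
        and key not in flat_dict
    ):
        flat_dict[key] = val

print(list(flat_dict))  # ['user_age']
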
12 changes: 9 additions & 3 deletions merlin/models/torch/block.py
@@ -98,8 +98,8 @@ def repeat(self, n: int = 1, name=None) -> "Block":
def repeat_parallel(self, n: int = 1, name=None) -> "ParallelBlock":
return repeat_parallel(self, n, name=name)

-    def repeat_parallel_like(self, like: HasKeys, name=None) -> "ParallelBlock":
-        return repeat_parallel_like(self, like, name=name)
+    def repeat_parallel_like(self, like: HasKeys, agg=None) -> "ParallelBlock":
+        return repeat_parallel_like(self, like, agg=agg)

def copy(self) -> "Block":
"""
@@ -718,7 +718,13 @@ def repeat_parallel(module: nn.Module, n: int = 1, agg=None) -> ParallelBlock:

def repeat_parallel_like(module: nn.Module, like: HasKeys, agg=None) -> ParallelBlock:
branches = {}
-    for i, key in enumerate(like.keys()):
+
+    if isinstance(like, Schema):
+        keys = like.column_names
+    else:
+        keys = list(like.keys())
+
+    for i, key in enumerate(keys):
if i == 0:
branches[str(key)] = module
else:
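With this change, repeat_parallel_like accepts a plain Schema as the like argument, not just objects with .keys(). A minimal sketch, assuming MLPBlock exposes the Block.repeat_parallel_like method and with illustrative column names:

import merlin.models.torch as mm
from merlin.schema import ColumnSchema, Schema

targets = Schema([ColumnSchema("click"), ColumnSchema("purchase")])

# One copy of the block per column name: branches "click" and "purchase".
towers = mm.MLPBlock([32, 1]).repeat_parallel_like(targets)
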
11 changes: 4 additions & 7 deletions merlin/models/torch/blocks/dlrm.py
@@ -42,13 +42,10 @@ class DLRMInputBlock(TabularInputBlock):
"""

-    def __init__(self, schema: Schema, dim: int, bottom_block: Block):
+    def __init__(self, schema: Optional[Schema], dim: int, bottom_block: Block):
         super().__init__(schema)
         self.add_route(Tags.CATEGORICAL, EmbeddingTables(dim, seq_combiner="mean"))
-        self.add_route(Tags.CONTINUOUS, bottom_block)
-
-        if "categorical" not in self:
-            raise ValueError("DLRMInputBlock must have a categorical input")
+        self.add_route(Tags.CONTINUOUS, bottom_block, required=False)


@docstring_parameter(dlrm_reference=_DLRM_REF)
@@ -117,7 +114,7 @@ class DLRMBlock(Block):
Parameters
----------
schema : Schema, optional
-        The schema to use for selection. Default is None.
+        The schema to use for selection.
dim : int
The dimensionality of the output vectors.
bottom_block : Block
Expand All @@ -139,7 +136,7 @@ class DLRMBlock(Block):

def __init__(
self,
-        schema: Schema,
+        schema: Optional[Schema],
dim: int,
bottom_block: Block,
top_block: Optional[Block] = None,
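Because the continuous route is now required=False, a categorical-only schema no longer raises. A rough sketch (the column and its properties are illustrative; real embedding tables also need cardinality information in the schema):

import merlin.models.torch as mm
from merlin.models.torch.blocks.dlrm import DLRMInputBlock
from merlin.schema import ColumnSchema, Schema, Tags

# No continuous columns: the Tags.CONTINUOUS route is skipped instead of failing.
schema = Schema([ColumnSchema("item_id", tags=[Tags.CATEGORICAL, Tags.ITEM_ID])])
inputs = DLRMInputBlock(schema, dim=16, bottom_block=mm.MLPBlock([16]))
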
79 changes: 72 additions & 7 deletions merlin/models/torch/inputs/tabular.py
@@ -14,13 +14,15 @@
# limitations under the License.
#

-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, Sequence, Union

from torch import nn

from merlin.models.torch.block import Block
from merlin.models.torch.inputs.embedding import EmbeddingTables
from merlin.models.torch.router import RouterBlock
+from merlin.models.torch.schema import Selection, select, select_union
+from merlin.models.torch.transforms.sequences import BroadcastToSequence
from merlin.models.utils.registry import Registry
from merlin.schema import Schema, Tags

@@ -56,8 +58,9 @@ def __init__(
agg: Optional[Union[str, nn.Module]] = None,
):
self.init = init
-        self.agg = agg
         super().__init__(schema)
+        if agg:
+            self.append(Block.parse(agg))

def initialize_from_schema(self, schema: Schema):
super().initialize_from_schema(schema)
@@ -69,8 +72,6 @@ def initialize_from_schema(self, schema: Schema):
raise ValueError(f"Initializer {self.init} not found.")

self.init(self)
-        if self.agg:
-            self.append(Block.parse(self.agg))

@classmethod
def register_init(cls, name: str):
@@ -96,7 +97,7 @@ def defaults(block: TabularInputBlock):


@TabularInputBlock.register_init("defaults")
-def defaults(block: TabularInputBlock):
+def defaults(block: TabularInputBlock, seq_combiner="mean"):
"""
Default initializer function for a TabularInputBlock.
@@ -106,5 +107,69 @@
Args:
block (TabularInputBlock): The block to initialize.
"""
-    block.add_route(Tags.CONTINUOUS)
-    block.add_route(Tags.CATEGORICAL, EmbeddingTables(seq_combiner="mean"))
+    block.add_route(Tags.CONTINUOUS, required=False)
+    block.add_route(Tags.CATEGORICAL, EmbeddingTables(seq_combiner=seq_combiner))
+
+
+@TabularInputBlock.register_init("defaults-no-combiner")
+def defaults_no_combiner(block: TabularInputBlock):
+    return defaults(block, seq_combiner=None)
+
+
+@TabularInputBlock.register_init("broadcast-context")
+def defaults_broadcast_to_seq(
+    block: TabularInputBlock,
+    seq_selection: Selection = Tags.SEQUENCE,
+    feature_selection: Sequence[Selection] = (Tags.CATEGORICAL, Tags.CONTINUOUS),
+):
+    context_selection = _not_seq(seq_selection, feature_selection=feature_selection)
+    block.add_route(context_selection, TabularInputBlock(init="defaults"), name="context")
+    block.add_route(
+        seq_selection,
+        TabularInputBlock(init="defaults-no-combiner"),
+        name="sequence",
+    )
+    block.append(BroadcastToSequence(context_selection, seq_selection, block.schema))
+
+
+def stack_context(
+    model_dim: int,
+    seq_selection: Selection = Tags.SEQUENCE,
+    projection_activation=None,
+    feature_selection: Sequence[Selection] = (Tags.CATEGORICAL, Tags.CONTINUOUS),
+):
+    def init_stacked_context(block: TabularInputBlock):
+        import merlin.models.torch as mm
+
+        mlp_kwargs = {"units": [model_dim], "activation": projection_activation}
+        context_selection = _not_seq(seq_selection, feature_selection=feature_selection)
+        context = TabularInputBlock(select(block.schema, context_selection))
+        context.add_route(Tags.CATEGORICAL, EmbeddingTables(seq_combiner=None))
+        context.add_route(Tags.CONTINUOUS, mm.MLPBlock(**mlp_kwargs))
+        context["categorical"].append_for_each(mm.MLPBlock(**mlp_kwargs))
+        context.append(mm.Stack(dim=1))
+
+        block.add_route(context.schema, context, name="context")
+        block.add_route(
+            seq_selection,
+            TabularInputBlock(init="defaults-no-combiner", agg=mm.Concat(dim=2)),
+            name="sequence",
+        )
+
+    return init_stacked_context
+
+
+def _not_seq(
+    seq_selection: Sequence[Selection],
+    feature_selection: Sequence[Selection] = (Tags.CATEGORICAL, Tags.CONTINUOUS),
+) -> Selection:
+    if not isinstance(seq_selection, (tuple, list)):
+        seq_selection = (seq_selection,)
+
+    def select_non_seq(schema: Schema) -> Schema:
+        seq = select_union(*seq_selection)(schema)
+        features = select_union(*feature_selection)(schema)
+
+        return features - seq
+
+    return select_non_seq
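
Taken together, the new initializers are picked via the init argument of TabularInputBlock. A rough usage sketch (schema, tags, and model_dim are illustrative):

import merlin.models.torch as mm
from merlin.schema import ColumnSchema, Schema, Tags

schema = Schema(
    [
        ColumnSchema("user_age", tags=[Tags.CONTINUOUS]),
        ColumnSchema("item_ids", tags=[Tags.CATEGORICAL, Tags.SEQUENCE]),
    ]
)

# Embed everything, then broadcast the non-sequential ("context") features
# to every step of the sequence:
broadcast = mm.TabularInputBlock(schema, init="broadcast-context")

# Project each feature to model_dim and stack context with the sequence,
# e.g. as input to an attention-style block:
stacked = mm.TabularInputBlock(schema, init=mm.stack_context(model_dim=64))
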
8 changes: 5 additions & 3 deletions merlin/models/torch/outputs/tabular.py
@@ -105,6 +105,8 @@ def defaults(block: TabularOutputBlock):
Args:
block (TabularOutputBlock): The block to initialize.
"""
-    block.add_route_for_each([Tags.CONTINUOUS, Tags.REGRESSION], RegressionOutput())
-    block.add_route_for_each(BinaryOutput.schema_selection, BinaryOutput())
-    block.add_route_for_each(CategoricalOutput.schema_selection, CategoricalOutput())
+    block.add_route_for_each([Tags.CONTINUOUS, Tags.REGRESSION], RegressionOutput(), required=False)
+    block.add_route_for_each(BinaryOutput.schema_selection, BinaryOutput(), required=False)
+    block.add_route_for_each(
+        CategoricalOutput.schema_selection, CategoricalOutput(), required=False
+    )
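
With required=False, output routes whose target type is absent from the schema are skipped instead of raising. A small sketch, assuming TabularOutputBlock is exported at the package root like the other blocks:

import merlin.models.torch as mm
from merlin.schema import ColumnSchema, Schema, Tags

# Only a binary target: the regression and categorical routes become no-ops.
schema = Schema([ColumnSchema("click", tags=[Tags.BINARY_CLASSIFICATION, Tags.TARGET])])
outputs = mm.TabularOutputBlock(schema, init="defaults")
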
14 changes: 12 additions & 2 deletions merlin/models/torch/router.py
@@ -65,6 +65,7 @@ def add_route(
selection: schema.Selection,
module: Optional[nn.Module] = None,
name: Optional[str] = None,
+        required: bool = True,
) -> "RouterBlock":
"""Add a new routing path for a given selection.
@@ -84,6 +85,8 @@
The module to append to the branch after selection.
name : str, optional
The name of the branch. Default is the name of the selection.
+        required : bool, optional
+            Whether the route is required. Default is True.
Returns
-------
@@ -96,6 +99,9 @@

routing_module = schema.select(self.selectable, selection)
if not routing_module:
+            if required:
+                raise ValueError(f"Selection {selection} not found in {self.selectable}")
+
return self

if module is not None:
@@ -124,7 +130,11 @@ def add_route(
return self

def add_route_for_each(
-        self, selection: schema.Selection, module: nn.Module, shared=False
+        self,
+        selection: schema.Selection,
+        module: nn.Module,
+        shared=False,
+        required: bool = True,
) -> "RouterBlock":
"""Add a new route for each column in a selection.
@@ -166,7 +176,7 @@ def add_route_for_each(
else:
col_module = deepcopy(module)

-            self.add_route(col, col_module, name=col.name)
+            self.add_route(col, col_module, name=col.name, required=required)

return self

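A short sketch of the new flag on a router-style block (schema and modules are illustrative):

import merlin.models.torch as mm
from merlin.schema import ColumnSchema, Schema, Tags

block = mm.TabularInputBlock(Schema([ColumnSchema("age", tags=[Tags.CONTINUOUS])]))

# The schema has no categorical columns: with required=False this route is
# silently skipped instead of raising "Selection ... not found".
block.add_route(Tags.CATEGORICAL, mm.EmbeddingTables(), required=False)
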
29 changes: 29 additions & 0 deletions merlin/models/torch/schema.py
@@ -442,6 +442,35 @@ def select_schema(schema: Schema, selection: Selection) -> Schema:
return selected


+def select_union(*selections: Selection) -> Selection:
+    """
+    Combine selections into a single selection.
+
+    This function returns a new function `combined_select` that, when called,
+    will perform the union operation on all the input selections.
+
+    Parameters
+    ----------
+    *selections : Selection
+        Variable length argument list of Selection instances.
+
+    Returns
+    -------
+    Selection
+        A function that takes a Schema as input and returns a Schema which
+        is the union of all selections.
+    """
+
+    def combined_select(schema: Schema) -> Schema:
+        output = Schema()
+        for s in selections:
+            output += select(schema, s)
+
+        return output
+
+    return combined_select


def selection_name(selection: Selection) -> str:
"""
Get the name of the selection.
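A quick sketch of how the combined selection behaves (column names are illustrative):

from merlin.models.torch.schema import select_union
from merlin.schema import ColumnSchema, Schema, Tags

schema = Schema(
    [
        ColumnSchema("age", tags=[Tags.CONTINUOUS]),
        ColumnSchema("item_id", tags=[Tags.CATEGORICAL]),
    ]
)

# select_union returns a callable Selection; applying it to a schema
# yields the union of the individual selections.
both = select_union(Tags.CONTINUOUS, Tags.CATEGORICAL)
print(both(schema).column_names)  # expected: ['age', 'item_id']
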
2 changes: 1 addition & 1 deletion merlin/models/torch/transforms/agg.py
@@ -110,7 +110,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
_sorted_tensors = []
for tensor in sorted_tensors:
if tensor.dim() < max_dims:
-                _sorted_tensors.append(tensor.unsqueeze(1))
+                _sorted_tensors.append(tensor.unsqueeze(-1))
else:
_sorted_tensors.append(tensor)
sorted_tensors = _sorted_tensors
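The switch to unsqueeze(-1) matters when concatenating per-step context onto sequence tensors, as Concat(dim=2) does in stack_context above. A plain-PyTorch illustration:

import torch

seq = torch.randn(8, 10, 16)  # (batch, seq_len, dim)
ctx = torch.randn(8, 10)      # (batch, seq_len): one value per step

# unsqueeze(1) -> (8, 1, 10), which cannot be concatenated with (8, 10, 16)
# on dim=2; unsqueeze(-1) -> (8, 10, 1), which lines up with the sequence.
out = torch.cat([seq, ctx.unsqueeze(-1)], dim=2)
print(out.shape)  # torch.Size([8, 10, 17])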